use super::config::ShardGranularity;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
/// Identifies one shard of the n-gram store.
///
/// `prefix` is the shard name. The constructors in this type produce either an
/// arbitrary caller-supplied string ([`ShardKey::new`]) or a zero-padded decimal
/// index ([`ShardKey::from_index`]). `order` is an optional qualifier that is
/// appended when the key is rendered as a file stem.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ShardKey {
    /// Shard name: a textual prefix or a zero-padded decimal index.
    pub prefix: String,
    /// Optional order qualifier; `None` means the key is not order-specific.
    pub order: Option<u8>,
}

impl ShardKey {
    /// Creates a key with the given prefix and no order qualifier.
    pub fn new(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
            order: None,
        }
    }

    /// Creates a key with the given prefix, qualified by `order`.
    pub fn with_order(prefix: impl Into<String>, order: u8) -> Self {
        Self {
            prefix: prefix.into(),
            order: Some(order),
        }
    }

    /// Creates an index-based key.
    ///
    /// The index is rendered in decimal, zero-padded to at least four digits
    /// (`42` becomes `"0042"`); larger indices simply use more digits.
    pub fn from_index(index: usize) -> Self {
        Self::new(format!("{:04}", index))
    }

    /// Returns `true` when the prefix is a non-empty run of ASCII digits.
    ///
    /// An empty prefix is *not* index-based: `Iterator::all` is vacuously true
    /// on an empty string, so emptiness has to be ruled out explicitly to keep
    /// this method consistent with [`ShardKey::as_index`].
    pub fn is_index_based(&self) -> bool {
        !self.prefix.is_empty() && self.prefix.chars().all(|c| c.is_ascii_digit())
    }

    /// Returns the numeric index for an index-based key.
    ///
    /// Yields `None` when the prefix is not index-based or when the digits do
    /// not fit in a `usize`.
    pub fn as_index(&self) -> Option<usize> {
        // The digit check comes first because `usize::from_str` also accepts a
        // leading `+`, which must not be treated as an index-based prefix.
        if self.is_index_based() {
            self.prefix.parse().ok()
        } else {
            None
        }
    }

    /// Renders the key as a file stem: `prefix` alone, or `prefix_order` when
    /// an order qualifier is present.
    pub fn as_file_stem(&self) -> String {
        match self.order {
            Some(order) => format!("{}_{}", self.prefix, order),
            None => self.prefix.clone(),
        }
    }
}
/// Hashes the prefix followed by the order qualifier.
///
/// These are exactly the two fields compared by the derived `PartialEq`/`Eq`,
/// so keys that compare equal always hash identically.
impl Hash for ShardKey {
    fn hash<H>(&self, hasher: &mut H)
    where
        H: Hasher,
    {
        // Destructure so that adding a field to `ShardKey` forces this impl to
        // be revisited.
        let Self { prefix, order } = self;
        prefix.hash(hasher);
        order.hash(hasher);
    }
}
/// Human-readable form: `prefix`, or `prefix:order` when an order qualifier is
/// present. (The file-stem form uses `_` as the separator instead.)
impl fmt::Display for ShardKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The prefix is always written; the order suffix only when set.
        f.write_str(&self.prefix)?;
        if let Some(order) = self.order {
            write!(f, ":{}", order)?;
        }
        Ok(())
    }
}
/// Computes the shard key for a `|`-separated n-gram.
///
/// * For `ShardGranularity::CpuProportional` the *whole* n-gram is hashed into
///   one of `granularity.num_shards()` index-based shards.
/// * Otherwise the key is derived from the first word: its alphabetic
///   characters are lowercased and the first
///   `granularity.prefix_len_for_order(order)` of them form the prefix. A
///   shorter prefix is right-padded with `'a'`; a first word without any
///   alphabetic character maps to a prefix made of `'_'`.
///
/// The returned prefix always contains exactly the requested number of
/// characters. All lengths are measured in characters, not bytes, so
/// non-ASCII letters are padded and truncated the same way as ASCII ones.
pub fn compute_shard_key(ngram: &str, order: u8, granularity: &ShardGranularity) -> ShardKey {
    if let ShardGranularity::CpuProportional { .. } = granularity {
        let num_shards = granularity.num_shards();
        return ShardKey::from_index(hash_to_shard(ngram, num_shards));
    }

    let prefix_len = granularity.prefix_len_for_order(order);
    // `split` always yields at least one piece, so the fallback is defensive only.
    let first_word = ngram.split('|').next().unwrap_or("");

    // Lowercase before truncating: a lowercase mapping may expand to several
    // characters, and `take` must bound the final character count.
    let letters: String = first_word
        .chars()
        .flat_map(|c| c.to_lowercase())
        .filter(|c| c.is_alphabetic())
        .take(prefix_len)
        .collect();

    let prefix = if letters.is_empty() {
        "_".repeat(prefix_len)
    } else {
        // `format!` pads by character count, so no byte-length check is needed;
        // a prefix that is already full width is returned unchanged.
        format!("{:a<width$}", letters, width = prefix_len)
    };
    ShardKey::new(prefix)
}
/// Maps `ngram` to a shard index in `0..num_shards` by hashing it.
///
/// The result is deterministic for a given toolchain because
/// `DefaultHasher::new()` takes no random seed.
/// NOTE(review): the standard library does not guarantee that
/// `DefaultHasher`'s algorithm stays the same across Rust releases; if shard
/// indices are persisted, confirm that this is acceptable.
///
/// # Panics
///
/// Panics if `num_shards` is zero.
fn hash_to_shard(ngram: &str, num_shards: usize) -> usize {
    assert!(num_shards > 0, "hash_to_shard requires at least one shard");

    let mut hasher = DefaultHasher::new();
    ngram.hash(&mut hasher);

    // Reduce in `u64` *before* narrowing. Casting the hash to `usize` first
    // would drop the upper bits on 32-bit targets and assign different shards
    // there than on 64-bit targets. The remainder is below `num_shards`, so
    // the final cast is lossless.
    (hasher.finish() % (num_shards as u64)) as usize
}
/// Maps the prefix of an existing file name to the shard key used under
/// `granularity`.
///
/// * Hash-based granularities hash `file_prefix` into one of
///   `granularity.num_shards()` index-based shards.
/// * Otherwise `file_prefix` is lowercased and then resized to exactly
///   `granularity.prefix_len_for_order(order)` characters: longer prefixes are
///   truncated, shorter ones are right-padded with `'a'`.
///
/// Lengths are measured in characters rather than bytes, so a non-ASCII
/// prefix is truncated on a character boundary instead of panicking.
pub fn shard_key_for_file_prefix(
    file_prefix: &str,
    order: u8,
    granularity: &ShardGranularity,
) -> ShardKey {
    if granularity.is_hash_based() {
        let num_shards = granularity.num_shards();
        return ShardKey::from_index(hash_to_shard(file_prefix, num_shards));
    }

    let target_len = granularity.prefix_len_for_order(order);

    // Truncate by characters; slicing the string by byte offset would panic
    // when the cut falls inside a multi-byte character.
    let truncated: String = file_prefix
        .to_lowercase()
        .chars()
        .take(target_len)
        .collect();

    // `format!` pads by character count and leaves a full-width string as is.
    ShardKey::new(format!("{:a<width$}", truncated, width = target_len))
}
/// Enumerates the shard keys for `granularity` at the given n-gram `order`.
///
/// * Hash-based granularities yield one index-based key per shard, in
///   ascending index order.
/// * Prefix-based granularities yield every lowercase `a`-`z` string of
///   length `granularity.prefix_len_for_order(order)`, in lexicographic order.
///
/// NOTE(review): only `a`-`z` prefixes are enumerated; the `_` fallback prefix
/// produced by `compute_shard_key` for non-alphabetic words is not part of
/// this list. Confirm that callers expect this.
pub fn all_shard_keys(granularity: &ShardGranularity, order: u8) -> Vec<ShardKey> {
    if !granularity.is_hash_based() {
        let alphabetic = AllPrefixIter {
            prefix_len: granularity.prefix_len_for_order(order),
            current: None,
        };
        return alphabetic.map(ShardKey::new).collect();
    }

    let shard_count = granularity.num_shards();
    let mut keys = Vec::with_capacity(shard_count);
    for index in 0..shard_count {
        keys.push(ShardKey::from_index(index));
    }
    keys
}
/// Iterator over every lowercase ASCII string (`a`-`z`) of a fixed length, in
/// lexicographic order: `"aa"`, `"ab"`, ..., `"zz"`.
struct AllPrefixIter {
    /// Number of letters in every produced string.
    prefix_len: usize,
    /// Bytes of the most recently produced string; `None` before the first call.
    current: Option<Vec<u8>>,
}

impl Iterator for AllPrefixIter {
    type Item = String;

    fn next(&mut self) -> Option<String> {
        match self.current.as_mut() {
            // First call: start at the all-'a' string.
            None => {
                let first = vec![b'a'; self.prefix_len];
                let text = String::from_utf8(first.clone()).unwrap();
                self.current = Some(first);
                Some(text)
            }
            // Later calls advance like an odometer: bump the rightmost letter
            // that is still below 'z' and reset everything to its right.
            Some(letters) => {
                // No such letter means the all-'z' string was already
                // produced; the state is left untouched, so the iterator
                // keeps returning `None`.
                let pivot = letters.iter().rposition(|&b| b < b'z')?;
                letters[pivot] += 1;
                letters[pivot + 1..].fill(b'a');
                Some(String::from_utf8(letters.clone()).unwrap())
            }
        }
    }
}
/// Returns the order of a `|`-separated n-gram, i.e. its number of tokens.
///
/// A string without any `|` (including the empty string) has order 1. Orders
/// above `u8::MAX` saturate at `u8::MAX` instead of wrapping around, so an
/// oversized n-gram can never be reported with a small, misleading order.
pub fn ngram_order(ngram: &str) -> u8 {
    let tokens = ngram.split('|').count();
    // The clamp makes the narrowing cast lossless.
    tokens.min(usize::from(u8::MAX)) as u8
}
/// Computes the shard key from an n-gram's first token alone.
///
/// * For `ShardGranularity::CpuProportional` the token is hashed into one of
///   `granularity.num_shards()` index-based shards. Only `first_token` is
///   hashed here, whereas `compute_shard_key` hashes the whole n-gram, so the
///   two can disagree for multi-token n-grams under this granularity.
/// * Otherwise the token's alphabetic characters are lowercased and the first
///   `granularity.prefix_len_for_order(order)` of them form the prefix. A
///   shorter prefix is right-padded with `'a'`; a token without any alphabetic
///   character maps to a prefix made of `'_'`.
///
/// The returned prefix always contains exactly the requested number of
/// characters. All lengths are measured in characters, not bytes, so
/// non-ASCII letters are padded and truncated the same way as ASCII ones.
pub fn compute_shard_key_from_token(
    first_token: &str,
    order: u8,
    granularity: &ShardGranularity,
) -> ShardKey {
    if let ShardGranularity::CpuProportional { .. } = granularity {
        let num_shards = granularity.num_shards();
        return ShardKey::from_index(hash_to_shard(first_token, num_shards));
    }

    let prefix_len = granularity.prefix_len_for_order(order);

    // Lowercase before truncating: a lowercase mapping may expand to several
    // characters, and `take` must bound the final character count.
    let letters: String = first_token
        .chars()
        .flat_map(|c| c.to_lowercase())
        .filter(|c| c.is_alphabetic())
        .take(prefix_len)
        .collect();

    let prefix = if letters.is_empty() {
        "_".repeat(prefix_len)
    } else {
        // `format!` pads by character count, so no byte-length check is needed;
        // a prefix that is already full width is returned unchanged.
        format!("{:a<width$}", letters, width = prefix_len)
    };
    ShardKey::new(prefix)
}
// Unit tests for shard-key construction, prefix- and hash-based key
// computation, key enumeration and n-gram order counting.
#[cfg(test)]
mod tests {
use super::*;
// FirstChar: the prefix is the lowercased first letter of the first word; a
// word without letters ("123") is expected to map to "_".
#[test]
fn test_compute_shard_key_first_char() {
let g = ShardGranularity::FirstChar;
let key = compute_shard_key("the|quick|brown", 3, &g);
assert_eq!(key.prefix, "t");
let key = compute_shard_key("APPLE", 1, &g);
assert_eq!(key.prefix, "a");
let key = compute_shard_key("123", 1, &g);
assert_eq!(key.prefix, "_"); }
// TwoChar: the prefix is the first two letters, lowercased; a one-letter
// first word is padded with 'a' ("a" -> "aa").
#[test]
fn test_compute_shard_key_two_char() {
let g = ShardGranularity::TwoChar;
let key = compute_shard_key("the|quick|brown", 3, &g);
assert_eq!(key.prefix, "th");
let key = compute_shard_key("a|b|c", 3, &g);
assert_eq!(key.prefix, "aa");
let key = compute_shard_key("ZEBRA", 1, &g);
assert_eq!(key.prefix, "ze"); }
// Adaptive: expects a one-letter prefix for order 1 and a two-letter prefix
// for orders 2 and 5.
#[test]
fn test_compute_shard_key_adaptive() {
let g = ShardGranularity::Adaptive;
let key = compute_shard_key("apple", 1, &g);
assert_eq!(key.prefix, "a");
let key = compute_shard_key("apple|pie", 2, &g);
assert_eq!(key.prefix, "ap");
let key = compute_shard_key("the|quick|brown|fox|jumps", 5, &g);
assert_eq!(key.prefix, "th");
}
// CpuProportional: keys are numeric indices, and hashing the same n-gram
// twice yields the same key.
#[test]
fn test_compute_shard_key_cpu_proportional() {
let g = ShardGranularity::CpuProportional {
multiplier: 2,
minimum: 8,
};
let key = compute_shard_key("the|quick|brown", 3, &g);
assert!(key.is_index_based());
let key2 = compute_shard_key("the|quick|brown", 3, &g);
assert_eq!(key.prefix, key2.prefix);
let key3 = compute_shard_key("apple|pie", 2, &g);
assert!(key3.is_index_based());
}
// Every hashed index must fall inside 0..num_shards, and the sample n-grams
// must land in at least two distinct shards. This is a loose sanity check,
// not a test of uniformity.
#[test]
fn test_hash_distribution() {
let g = ShardGranularity::CpuProportional {
multiplier: 1,
minimum: 16, };
let num_shards = g.num_shards();
let mut shard_counts = vec![0usize; num_shards];
let test_ngrams = [
"the",
"quick",
"brown",
"fox",
"jumps",
"over",
"lazy",
"dog",
"apple",
"banana",
"cherry",
"date",
"elderberry",
"fig",
"grape",
"hello|world",
"foo|bar",
"test|data",
"n-gram|model",
"machine|learning",
];
for ngram in &test_ngrams {
let key = compute_shard_key(ngram, 1, &g);
let index = key.as_index().expect("Should be index-based");
assert!(
index < num_shards,
"Index {} out of range for {} shards",
index,
num_shards
);
shard_counts[index] += 1;
}
let non_empty = shard_counts.iter().filter(|&&c| c > 0).count();
assert!(
non_empty >= 2,
"Hash distribution too skewed: only {} non-empty shards",
non_empty
);
}
// from_index zero-pads to four digits and round-trips through as_index.
#[test]
fn test_shard_key_from_index() {
let key = ShardKey::from_index(0);
assert_eq!(key.prefix, "0000");
assert!(key.is_index_based());
assert_eq!(key.as_index(), Some(0));
let key = ShardKey::from_index(42);
assert_eq!(key.prefix, "0042");
assert_eq!(key.as_index(), Some(42));
let key = ShardKey::from_index(9999);
assert_eq!(key.prefix, "9999");
assert_eq!(key.as_index(), Some(9999));
}
// Any all-digit prefix counts as index-based, whichever constructor built the
// key; alphabetic and "_" prefixes do not.
#[test]
fn test_shard_key_is_index_based() {
assert!(ShardKey::from_index(0).is_index_based());
assert!(ShardKey::new("0000").is_index_based());
assert!(ShardKey::new("1234").is_index_based());
assert!(!ShardKey::new("th").is_index_based());
assert!(!ShardKey::new("apple").is_index_based());
assert!(!ShardKey::new("_").is_index_based());
}
// TwoChar file prefixes: an exact-length prefix is kept, a short one is padded
// with 'a', and a long one is truncated to two characters.
#[test]
fn test_shard_key_for_file_prefix() {
let g = ShardGranularity::TwoChar;
let key = shard_key_for_file_prefix("th", 2, &g);
assert_eq!(key.prefix, "th");
let key = shard_key_for_file_prefix("a", 2, &g);
assert_eq!(key.prefix, "aa");
let key = shard_key_for_file_prefix("the", 2, &g);
assert_eq!(key.prefix, "th");
}
// CpuProportional file prefixes map to index-based keys.
#[test]
fn test_shard_key_for_file_prefix_cpu_proportional() {
let g = ShardGranularity::CpuProportional {
multiplier: 2,
minimum: 8,
};
let key = shard_key_for_file_prefix("th", 2, &g);
assert!(key.is_index_based());
}
// FirstChar enumerates the 26 keys "a" through "z".
#[test]
fn test_all_shard_keys_first_char() {
let g = ShardGranularity::FirstChar;
let keys = all_shard_keys(&g, 1);
assert_eq!(keys.len(), 26);
assert_eq!(keys[0].prefix, "a");
assert_eq!(keys[25].prefix, "z");
}
// TwoChar enumerates 26 * 26 = 676 keys, from "aa" to "zz".
#[test]
fn test_all_shard_keys_two_char() {
let g = ShardGranularity::TwoChar;
let keys = all_shard_keys(&g, 2);
assert_eq!(keys.len(), 676); assert_eq!(keys[0].prefix, "aa");
assert_eq!(keys[675].prefix, "zz");
}
// CpuProportional enumerates one index-based key per shard, in index order.
#[test]
fn test_all_shard_keys_cpu_proportional() {
let g = ShardGranularity::CpuProportional {
multiplier: 1,
minimum: 16,
};
let keys = all_shard_keys(&g, 1);
let num_shards = g.num_shards();
assert_eq!(keys.len(), num_shards);
for (i, key) in keys.iter().enumerate() {
assert!(key.is_index_based());
assert_eq!(key.as_index(), Some(i));
}
}
// Display prints the prefix alone, or "prefix:order" when an order is set.
#[test]
fn test_shard_key_display() {
let key = ShardKey::new("th");
assert_eq!(format!("{}", key), "th");
let key = ShardKey::with_order("th", 3);
assert_eq!(format!("{}", key), "th:3");
let key = ShardKey::from_index(42);
assert_eq!(format!("{}", key), "0042");
}
// The file stem uses '_' (not ':') between the prefix and the order.
#[test]
fn test_shard_key_file_stem() {
let key = ShardKey::new("th");
assert_eq!(key.as_file_stem(), "th");
let key = ShardKey::with_order("th", 3);
assert_eq!(key.as_file_stem(), "th_3");
let key = ShardKey::from_index(42);
assert_eq!(key.as_file_stem(), "0042");
}
// The order is the number of '|'-separated tokens.
#[test]
fn test_ngram_order() {
assert_eq!(ngram_order("apple"), 1);
assert_eq!(ngram_order("apple|pie"), 2);
assert_eq!(ngram_order("the|quick|brown|fox|jumps"), 5);
}
// Token-based variant under TwoChar: same lowercasing, 'a' padding and "__"
// fallback for a token without letters.
#[test]
fn test_compute_shard_key_from_token_two_char() {
let g = ShardGranularity::TwoChar;
let key = compute_shard_key_from_token("the", 3, &g);
assert_eq!(key.prefix, "th");
let key = compute_shard_key_from_token("apple", 2, &g);
assert_eq!(key.prefix, "ap");
let key = compute_shard_key_from_token("ZEBRA", 1, &g);
assert_eq!(key.prefix, "ze");
let key = compute_shard_key_from_token("a", 2, &g);
assert_eq!(key.prefix, "aa");
let key = compute_shard_key_from_token("123", 1, &g);
assert_eq!(key.prefix, "__");
}
// Token-based variant under Adaptive: one letter for order 1, two letters for
// orders 2 and 3.
#[test]
fn test_compute_shard_key_from_token_adaptive() {
let g = ShardGranularity::Adaptive;
let key = compute_shard_key_from_token("apple", 1, &g);
assert_eq!(key.prefix, "a");
let key = compute_shard_key_from_token("the", 2, &g);
assert_eq!(key.prefix, "th");
let key = compute_shard_key_from_token("quick", 3, &g);
assert_eq!(key.prefix, "qu");
}
// Token-based variant under CpuProportional: index-based and repeatable.
#[test]
fn test_compute_shard_key_from_token_cpu_proportional() {
let g = ShardGranularity::CpuProportional {
multiplier: 2,
minimum: 8,
};
let key = compute_shard_key_from_token("the", 3, &g);
assert!(key.is_index_based());
let key2 = compute_shard_key_from_token("the", 3, &g);
assert_eq!(key.prefix, key2.prefix);
let key3 = compute_shard_key_from_token("apple", 2, &g);
assert!(key3.is_index_based());
}
// For the prefix-based granularities, the token-based variant must agree with
// compute_shard_key when given the n-gram's first token. CpuProportional is
// not part of this check.
#[test]
fn test_compute_shard_key_from_token_matches_compute_shard_key() {
let granularities = [
ShardGranularity::FirstChar,
ShardGranularity::TwoChar,
ShardGranularity::Adaptive,
];
for g in &granularities {
let key_from_ngram = compute_shard_key("the|quick|brown", 3, g);
let key_from_token = compute_shard_key_from_token("the", 3, g);
assert_eq!(key_from_ngram, key_from_token, "Mismatch for {:?}", g);
let key_from_ngram = compute_shard_key("apple|pie", 2, g);
let key_from_token = compute_shard_key_from_token("apple", 2, g);
assert_eq!(key_from_ngram, key_from_token, "Mismatch for {:?}", g);
}
}
}