use itertools::*;
use lazy_static::*;
use rand::distributions::*;
use rand::{CryptoRng, Rng};
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::ops::DerefMut;
use std::sync::{Mutex, MutexGuard};
use unicode_segmentation::UnicodeSegmentation;
use unidecode::unidecode_char;
use Result::{Err, Ok};
/// Punctuation and symbol characters that are dropped before transliteration.
///
/// `-` is in this list, which is why it can safely be used as the padding
/// character for short words when building tri-grams.
const FILTERED_CHARS: [char; 31] = [
    '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '{', '}', '_', '<', '>', ':', ';', ',', '.',
    '"', '\'', '`', '|', '+', '=', '/', '~', '[', ']', '\\', '-',
];

/// Returns `true` when `c` is not one of the [`FILTERED_CHARS`].
///
/// Takes `&char` (rather than `char`) so it can be handed straight to
/// `Iterator::filter` over `str::chars()`.
fn should_keep_char(c: &char) -> bool {
    FILTERED_CHARS.iter().all(|filtered| filtered != c)
}
/// Maximum accepted input length, measured in UTF-8 bytes (`str::len`). It also caps
/// how far `generate_hashes_for_string_with_padding` will grow a hash set.
const MAX_STRING_LEN: usize = 200;

lazy_static! {
    /// Uniform distribution over every `u32`; the source of random padding values.
    static ref ALL_U32: Uniform<u32> = Uniform::new_inclusive(u32::min_value(), u32::max_value());
    /// Uniform distribution over `1..=200`; used to pick a padding-size bucket.
    static ref ONE_TO_TWO_HUNDRED: Uniform<u8> = Uniform::new_inclusive(1u8, 200u8);
}
/// Hashes the tri-grams of `s` (like [`generate_hashes_for_string`]) and then pads the
/// resulting set with a random number of uniformly random `u32` values.
///
/// The padding size is drawn from a skewed distribution. A first draw `prob` is uniform
/// in `1..=200`; rand's two-argument `gen_range(low, high)` excludes `high`.
/// - `prob <= 1` (0.5%): 1..=199 extra values
/// - `prob <= 5` (2%): 1..=29
/// - `prob <= 50` (22.5%): 1..=9
/// - otherwise (75%): 1..=4
///
/// The padding never pushes the set past `MAX_STRING_LEN` entries. If the set is already
/// at or above that size, no padding is added. Random values that collide with existing
/// hashes are absorbed by the set, so the number actually added can be smaller.
///
/// # Errors
/// Propagates the error from [`generate_hashes_for_string`] when `s` is too long.
///
/// # Panics
/// Panics if the `rng` mutex is poisoned (see `take_lock`).
pub fn generate_hashes_for_string_with_padding<R: Rng + CryptoRng>(
    s: &str,
    partition_id: Option<&str>,
    salt: &[u8],
    rng: &Mutex<R>,
) -> Result<HashSet<u32>, String> {
    let mut hashes = generate_hashes_for_string(s, partition_id, salt)?;

    // Take the lock once so all three random draws come from one uninterrupted
    // sequence, instead of re-locking (and possibly interleaving) between them.
    let mut guard = take_lock(rng);
    let r = guard.deref_mut();

    let prob = r.sample(*ONE_TO_TWO_HUNDRED);
    let to_add: u8 = if prob <= 1 {
        r.gen_range(1, 200)
    } else if prob <= 5 {
        r.gen_range(1, 30)
    } else if prob <= 50 {
        r.gen_range(1, 10)
    } else {
        r.gen_range(1, 5)
    };

    // The input limit is in bytes, but transliteration can expand the text (e.g. one CJK
    // char becomes several Latin letters). So the tri-gram set may already exceed
    // MAX_STRING_LEN entries. Saturate instead of underflowing, which would panic in
    // debug builds.
    let pad_len = MAX_STRING_LEN
        .saturating_sub(hashes.len())
        .min(usize::from(to_add));

    hashes.extend(r.sample_iter(*ALL_U32).take(pad_len));
    Ok(hashes)
}
/// Computes the set of salted, truncated SHA-256 hashes of the tri-grams of `s`.
///
/// Each tri-gram from `make_tri_grams` is hashed as
/// `SHA256(partition_id || salt || tri_gram)`, with `partition_id` omitted when it is
/// `None`. The first four bytes of each digest are read as a big-endian `u32`.
///
/// # Errors
/// Returns `Err` when `s` is longer than `MAX_STRING_LEN`. The check uses the UTF-8
/// byte length (`str::len`), even though the message speaks of chars.
pub fn generate_hashes_for_string(
    s: &str,
    partition_id: Option<&str>,
    salt: &[u8],
) -> Result<HashSet<u32>, String> {
    if s.len() > MAX_STRING_LEN {
        return Err(format!("The input string is too long. This function only supports strings that are no longer than {} chars.", MAX_STRING_LEN));
    }

    // The prefix (partition id, then salt) is identical for every tri-gram.
    // Absorb it once and clone the hasher state per tri-gram.
    let mut prefix = Sha256::new();
    if let Some(id) = partition_id {
        prefix.input(id.as_bytes());
    }
    prefix.input(salt);

    let hashes: HashSet<u32> = make_tri_grams(s)
        .iter()
        .map(|tri_gram| {
            let mut hasher = prefix.clone();
            hasher.input(tri_gram.as_bytes());
            as_u32_be(&hasher.result().into())
        })
        .collect();
    Ok(hashes)
}
pub fn transliterate_string(s: &str) -> String {
s.chars()
.filter(should_keep_char)
.map(char_to_trans)
.collect()
}
/// Breaks `s` into the set of distinct character tri-grams of its transliterated words.
///
/// `s` is first passed through `transliterate_string` and then split with
/// `unicode_words`. Words shorter than three chars are right-padded with `-` to exactly
/// three chars; `-` itself is removed by `should_keep_char` before transliteration.
/// Each (padded) word contributes its sliding tri-grams via `word_to_trigrams`.
fn make_tri_grams(s: &str) -> HashSet<String> {
    let mut tri_grams = HashSet::new();
    for word in transliterate_string(s).unicode_words() {
        // Format width counts chars, not bytes, so multibyte words pad correctly.
        let padded = if word.chars().count() < 3 {
            format!("{:-<3}", word)
        } else {
            word.to_string()
        };
        tri_grams.extend(word_to_trigrams(&padded));
    }
    tri_grams
}
/// Returns the distinct windows of three consecutive chars in `s`.
///
/// A string with fewer than three chars yields an empty set.
fn word_to_trigrams(s: &str) -> HashSet<String> {
    let mut out = HashSet::new();
    for (first, second, third) in s.chars().tuple_windows() {
        let mut tri_gram = String::new();
        tri_gram.push(first);
        tri_gram.push(second);
        tri_gram.push(third);
        out.insert(tri_gram);
    }
    out
}
/// Transliterates one char with `unidecode_char` and lowercases the result.
///
/// When `unidecode_char` has no transliteration (it returns an empty string), the
/// original char is returned unchanged; note that it is not lowercased in that case.
fn char_to_trans(c: char) -> String {
    match unidecode_char(c) {
        "" => c.to_string(),
        latin => latin.to_lowercase(),
    }
}
/// Reads the first four bytes of a 32-byte digest as a big-endian `u32`.
/// The remaining 28 bytes are ignored.
#[inline]
fn as_u32_be(slice: &[u8; 32]) -> u32 {
    u32::from_be_bytes([slice[0], slice[1], slice[2], slice[3]])
}
fn take_lock<T>(m: &Mutex<T>) -> MutexGuard<T> {
m.lock().unwrap_or_else(|e| {
let error = format!("Error when acquiring lock: {}", e);
panic!(error);
})
}
// Unit tests for the tri-gram extraction, transliteration and hashing helpers above.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;

    /// Builds an owned `HashSet<String>` from string literals so it can be compared
    /// against the tri-gram sets returned by the functions under test.
    fn make_set(array: &[&str]) -> HashSet<String> {
        array
            .into_iter()
            .map(|&s| From::from(s))
            .collect::<HashSet<_>>()
    }

    // Bytes 1, 2, 3, 4 read big-endian are 0x01020304 == 16909060.
    #[test]
    fn as_u32_be_known_result() {
        let known_result = 16909060u32;
        let mut input = [0u8; 32];
        input[0] = 1;
        input[1] = 2;
        input[2] = 3;
        input[3] = 4;
        let result = as_u32_be(&input);
        assert_eq!(result, known_result);
    }

    // Filtered punctuation (',' and '!') is dropped and transliterated output is lowercased.
    // The CJK case shows transliteration expanding the text, including a trailing space.
    #[test]
    fn string_transliterated() {
        assert_eq!(transliterate_string("Gumby, dammit!"), "gumby dammit");
        assert_eq!(transliterate_string("北亰"), "bei jing ");
        assert_eq!(transliterate_string("Æneid"), "aeneid");
    }

    // A four-char word has exactly two sliding tri-grams.
    #[test]
    fn word_to_trigrams_known() {
        let result = word_to_trigrams("five");
        assert_eq!(result, make_set(&["fiv", "ive"]));
    }

    // Accents are stripped ("José" -> "jose"), and the filtered '-' joins the phone
    // number into one word "8121117654". The repeated "111" appears only once in the set.
    #[test]
    fn make_tri_grams_works_multi_word() {
        assert_eq!(
            make_tri_grams("123 José Núñez 812-111-7654"),
            make_set(&[
                "123", "jos", "ose", "nun", "une", "nez", "812", "121", "211", "111", "117", "176",
                "765", "654",
            ])
        );
    }

    // Dotted capital I ('İ') transliterates to a plain lowercase 'i'.
    #[test]
    fn make_tri_grams_works_non_ascii() {
        assert_eq!(
            make_tri_grams("TİRYAKİ"),
            make_set(&["tir", "iry", "rya", "yak", "aki"])
        );
    }

    // Repeated words contribute each tri-gram only once.
    #[test]
    fn make_tri_grams_eliminates_duplicates() {
        assert_eq!(
            make_tri_grams("TİRYAKİ TİRYAKİ"),
            make_set(&["tir", "iry", "rya", "yak", "aki"])
        );
    }

    // A two-char word is right-padded with '-' to a single tri-gram.
    #[test]
    fn make_tri_grams_works_short_non_ascii() {
        assert_eq!(make_tri_grams("Tİ"), make_set(&["ti-"]));
    }

    // Each CJK char becomes its own word after transliteration; the one-char "i" is padded.
    #[test]
    fn make_tri_grams_works_multichar_translate() {
        assert_eq!(
            make_tri_grams("志 豪 İ"),
            make_set(&["zhi", "hao", "i--"])
        );
    }

    // Arabic script is transliterated; the two-letter second word is padded.
    #[test]
    fn make_tri_grams_works_arabic() {
        assert_eq!(
            make_tri_grams("شريط فو"),
            make_set(&["shr", "hry", "ryt", "fw-"])
        );
    }

    // Chars with no transliteration are kept verbatim. Padding counts chars (not bytes),
    // so a two-char multibyte word still gets exactly one '-'.
    #[test]
    fn make_tri_grams_works_short_multibyte() {
        assert_eq!(
            make_tri_grams("\u{102AE}\u{102AF}"),
            make_set(&["\u{102AE}\u{102AF}-"])
        );
    }

    #[test]
    fn char_to_trans_latinizable() {
        assert_eq!(char_to_trans('İ'), "i")
    }

    // When unidecode has no mapping, the original char is returned unchanged.
    #[test]
    fn char_to_trans_not_latinizable() {
        let c = "\u{102AE}".chars().nth(0).unwrap();
        assert_eq!(char_to_trans(c), "\u{102AE}")
    }

    // "123" has a single tri-gram. Its hash must equal the first 4 bytes (big-endian)
    // of SHA256(partition_id || salt || tri_gram), computed here independently.
    #[test]
    fn generate_hashes_for_string_compute_known_value() -> Result<(), String> {
        let result = generate_hashes_for_string("123", Some("foo"), &[0u8; 1])?;
        let expected_result = {
            let mut hasher = Sha256::new();
            hasher.input("foo".as_bytes());
            hasher.input([0u8; 1]);
            hasher.input("123");
            as_u32_be(&(hasher.result().into()))
        };
        assert_eq!(result, [expected_result].iter().map(|x| *x).collect());
        Ok(())
    }

    // "123" yields one real hash and the padding always draws at least one value. This
    // can only fail if a random u32 collides with the real hash, which is astronomically
    // unlikely.
    #[test]
    fn generate_hashes_for_string_with_padding_adds_at_least_one() -> Result<(), String> {
        let rng = Mutex::new(ThreadRng::default());
        let result = generate_hashes_for_string_with_padding("123", Some("foo"), &[0u8; 1], &rng)?;
        assert!(result.len() > 1);
        Ok(())
    }

    // An empty string has no tri-grams, so every element comes from padding.
    #[test]
    fn generate_hashes_for_string_with_padding_empty_string() -> Result<(), String> {
        let rng = Mutex::new(ThreadRng::default());
        let result = generate_hashes_for_string_with_padding("", Some("foo"), &[0u8; 1], &rng)?;
        assert!(result.len() >= 1);
        Ok(())
    }

    // 201 ASCII alphanumerics is 201 bytes, one over MAX_STRING_LEN, so hashing must fail.
    #[test]
    fn generate_hashes_for_string_too_long_errors() -> Result<(), String> {
        let rng = ThreadRng::default();
        let input: String = rng
            .sample_iter(rand::distributions::Alphanumeric)
            .take(201)
            .collect();
        generate_hashes_for_string(&input, Some("foo"), &[0u8; 1]).unwrap_err();
        Ok(())
    }
}