use std::{
cell::RefCell,
hash::{Hash, Hasher},
};
#[cfg(not(target_family = "wasm"))]
use std::time::{SystemTime, UNIX_EPOCH};
#[cfg(target_family = "wasm")]
use web_time::{SystemTime, UNIX_EPOCH};
use num::bigint;
use rand::RngExt;
use sha3::{Digest, Sha3_512};
/// Length, in characters, that `CuidConstructor::new` configures for
/// generated IDs.
#[doc(hidden)]
pub const DEFAULT_LENGTH: u8 = 24;
/// Length of the per-thread fingerprint string, and the maximum length
/// accepted by `is_cuid2`.
const BIG_LENGTH: u8 = 32;
/// Bytes an ID may start with: the lowercase ASCII letters. Used both to pick
/// the first character of a new ID and to validate the first character of a
/// candidate.
const STARTING_CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyz";
/// Builds a fingerprint string for the calling thread.
///
/// Four 128-bit big-endian blocks are hashed together, in this order:
/// two random values, the process ID (on `wasm` targets a third random value
/// is used instead), and the hashed ID of the current thread. The result is
/// cut to `BIG_LENGTH` base-36 characters by [`hash`].
fn fingerprint() -> String {
    let mut rng = rand::rng();

    // The blocks are produced in the same order in which they are hashed.
    let first_salt = rng.random::<u128>().to_be_bytes();
    let second_salt = rng.random::<u128>().to_be_bytes();

    #[cfg(not(target_family = "wasm"))]
    let process_block = u128::from(std::process::id()).to_be_bytes();
    #[cfg(target_family = "wasm")]
    let process_block = rng.random::<u128>().to_be_bytes();

    let thread_block = u128::from(get_thread_id()).to_be_bytes();

    hash(
        [first_salt, second_salt, process_block, thread_block],
        u16::from(BIG_LENGTH),
    )
}
thread_local! {
// Per-thread random starting value for the counter.
// NOTE(review): the upper bound 476_782_367 presumably mirrors the reference
// CUID2 implementation - confirm.
static COUNTER_INIT: u64 = rand::random_range(0..476_782_367);
// Per-thread counter, seeded from COUNTER_INIT and advanced by `get_count`.
static COUNTER: RefCell<u64> = COUNTER_INIT.with(|val| RefCell::new(*val));
// Per-thread fingerprint, computed once per thread by `fingerprint` and
// cloned out by `get_fingerprint`.
static FINGERPRINT: String = fingerprint();
}
/// Hashes every block of `input` with SHA3-512 and returns the digest as a
/// base-36 string cut down to at most `length` characters.
///
/// The digest bytes are read as one big-endian unsigned integer before being
/// rendered in radix 36. If that rendering has fewer than `length` digits
/// (a 512-bit digest yields roughly 100 digits at most), the string is
/// returned as-is and is therefore shorter than `length`.
fn hash<S: AsRef<[u8]>, T: IntoIterator<Item = S>>(input: T, length: u16) -> String {
    let mut state = Sha3_512::new();
    input
        .into_iter()
        .for_each(|block| state.update(block.as_ref()));
    let digest = state.finalize();

    let mut encoded = bigint::BigUint::from_bytes_be(&digest).to_str_radix(36);
    // Radix-36 output is pure ASCII, so truncating at any byte offset is a
    // valid char boundary.
    encoded.truncate(usize::from(length));
    encoded
}
/// Returns `true` if `to_check` has the shape of a CUID2.
///
/// The check is purely syntactic: the string must be between 2 and
/// `BIG_LENGTH` (32) bytes long, start with a lowercase ASCII letter, and
/// otherwise contain only lowercase ASCII letters and digits.
#[inline]
pub fn is_cuid2<S: AsRef<str>>(to_check: S) -> bool {
    is_cuid2_inner::<S, { BIG_LENGTH as usize }>(to_check)
}
fn is_cuid2_inner<S: AsRef<str>, const MAX_LENGTH: usize>(to_check: S) -> bool {
let to_check = to_check.as_ref().as_bytes();
if (2..=MAX_LENGTH).contains(&to_check.len())
&& let [first, tail @ ..] = to_check
{
return STARTING_CHARS.contains(first)
&& tail.iter().all(|x| matches!(x, b'0'..=b'9' | b'a'..=b'z'));
}
false
}
/// Alias for [`is_cuid2`]: returns `true` if `to_check` has the shape of a
/// CUID2.
#[inline]
pub fn is_cuid<S: AsRef<str>>(to_check: S) -> bool {
    is_cuid2::<S>(to_check)
}
/// Produces a string of exactly `length` random base-36 characters
/// (`0`-`9`, `a`-`z`) drawn from `rng`.
///
/// One value is drawn from `rng` per output character, in order.
fn create_entropy(length: u16, rng: &mut impl RngExt) -> String {
    (0..length)
        .map(|_| {
            // Half-open range: digits 0 through 35, i.e. every base-36 digit.
            let digit: u32 = rng.random_range(0..36);
            char::from_digit(digit, 36).expect("range is within radix")
        })
        .collect()
}
/// Returns the current system time as whole milliseconds since the Unix
/// epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn get_timestamp() -> u128 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).expect(
        "Failed to calculate system timestamp! Current system time may be \
         set to before the Unix epoch, or time may otherwise be broken. \
         Cannot continue",
    );
    since_epoch.as_millis()
}
/// Returns the calling thread's current counter value and then advances the
/// counter by one, wrapping around on overflow.
#[inline]
fn get_count() -> u64 {
    COUNTER.with(|cell| {
        let mut slot = cell.borrow_mut();
        let previous = *slot;
        *slot = previous.wrapping_add(1);
        previous
    })
}
/// Returns a copy of the calling thread's fingerprint string.
#[inline]
fn get_fingerprint() -> String {
    FINGERPRINT.with(String::clone)
}
/// Reduces the current thread's `ThreadId` to a `u64` by feeding it through a
/// default-constructed `ahash::AHasher`.
fn get_thread_id() -> u64 {
    let thread_id = std::thread::current().id();
    let mut state = ahash::AHasher::default();
    thread_id.hash(&mut state);
    state.finish()
}
/// Configurable generator of CUID2-style identifiers.
///
/// Holds the desired ID length together with the two function hooks that are
/// called each time an ID is created. All fields are plain data or function
/// pointers, so the type is cheap to copy and can be debug-printed.
#[derive(Debug, Clone, Copy)]
pub struct CuidConstructor {
    /// Total length, in characters, of the IDs to generate.
    length: u16,
    /// Called once per generated ID to obtain a counter value.
    counter: fn() -> u64,
    /// Called once per generated ID to obtain a fingerprint string.
    fingerprinter: fn() -> String,
}
impl CuidConstructor {
    /// Creates a constructor with the default configuration: IDs of
    /// `DEFAULT_LENGTH` characters, the per-thread counter, and the
    /// per-thread fingerprint.
    pub const fn new() -> Self {
        Self {
            fingerprinter: get_fingerprint,
            counter: get_count,
            length: DEFAULT_LENGTH as u16,
        }
    }

    /// Returns this constructor with its ID length replaced by `length`.
    ///
    /// # Panics
    ///
    /// Panics if `length` is less than 2.
    pub const fn with_length(self, length: u16) -> Self {
        assert!(length >= 2, "CUID length must be at least 2");
        Self {
            length,
            counter: self.counter,
            fingerprinter: self.fingerprinter,
        }
    }

    /// Returns this constructor with its counter function replaced by
    /// `counter`.
    pub const fn with_counter(self, counter: fn() -> u64) -> Self {
        Self {
            length: self.length,
            counter,
            fingerprinter: self.fingerprinter,
        }
    }

    /// Returns this constructor with its fingerprint function replaced by
    /// `fingerprinter`.
    pub const fn with_fingerprinter(self, fingerprinter: fn() -> String) -> Self {
        Self {
            length: self.length,
            counter: self.counter,
            fingerprinter,
        }
    }

    /// Replaces the ID length in place.
    ///
    /// # Panics
    ///
    /// Panics if `length` is less than 2.
    pub fn set_length(&mut self, length: u16) {
        assert!(length >= 2, "CUID length must be at least 2");
        self.length = length;
    }

    /// Replaces the counter function in place.
    pub fn set_counter(&mut self, counter: fn() -> u64) {
        self.counter = counter;
    }

    /// Replaces the fingerprint function in place.
    pub fn set_fingerprinter(&mut self, fingerprinter: fn() -> String) {
        self.fingerprinter = fingerprinter;
    }

    /// Generates a new ID.
    ///
    /// The ID is one random lowercase ASCII letter followed by up to
    /// `length - 1` base-36 characters. Those characters come from hashing,
    /// in order: the current timestamp in milliseconds, `length` characters
    /// of fresh random entropy, the counter value, and the fingerprint.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the Unix epoch.
    #[inline]
    pub fn create_id(&self) -> String {
        let timestamp = get_timestamp();
        let mut rng = rand::rng();
        let salt = create_entropy(self.length, &mut rng);
        let counter_value = (self.counter)();
        let host_fingerprint = (self.fingerprinter)();

        let timestamp_bytes = timestamp.to_be_bytes();
        let counter_bytes = counter_value.to_be_bytes();
        let blocks: [&[u8]; 4] = [
            &timestamp_bytes,
            salt.as_bytes(),
            &counter_bytes,
            host_fingerprint.as_bytes(),
        ];
        // The builder and setter reject lengths below 2, so this cannot
        // underflow; one slot is reserved for the leading letter.
        let body = hash(blocks, self.length - 1);

        let lead = STARTING_CHARS[rng.random_range(0..STARTING_CHARS.len())];

        let mut id = String::with_capacity(body.len() + 1);
        id.push(char::from(lead));
        id.push_str(&body);
        id
    }
}
impl Default for CuidConstructor {
fn default() -> Self {
Self::new()
}
}
/// Shared constructor with the default configuration; backs `create_id`.
static DEFAULT_CONSTRUCTOR: CuidConstructor = CuidConstructor::new();
/// Length, in characters, of IDs produced by `slug`, and the maximum length
/// accepted by `is_slug`.
const SLUG_LENGTH: u16 = 10;
/// Shared constructor configured for `SLUG_LENGTH`-character IDs; backs
/// `slug`.
static SLUG_CONSTRUCTOR: CuidConstructor = CuidConstructor::new().with_length(SLUG_LENGTH);
#[inline]
pub fn create_id() -> String {
DEFAULT_CONSTRUCTOR.create_id()
}
#[inline]
pub fn cuid() -> String {
create_id()
}
#[inline]
pub fn slug() -> String {
SLUG_CONSTRUCTOR.create_id()
}
/// Returns `true` if `to_check` has the shape of a slug.
///
/// The check is purely syntactic: the string must be between 2 and
/// `SLUG_LENGTH` (10) bytes long, start with a lowercase ASCII letter, and
/// otherwise contain only lowercase ASCII letters and digits.
#[inline]
pub fn is_slug<S: AsRef<str>>(to_check: S) -> bool {
    is_cuid2_inner::<S, { SLUG_LENGTH as usize }>(to_check)
}
// Unit tests for ID generation: counter behaviour, format validity,
// uniqueness, and distribution of the generated values.
#[cfg(test)]
mod test {
use std::{collections::HashSet, thread};
use super::*;
// Generates a `wasm_<name>` function annotated with `wasm_bindgen_test` that
// simply calls the existing test function `<name>`.
macro_rules! wasm_test {
($name:ident) => {
paste::paste! {
#[wasm_bindgen_test::wasm_bindgen_test]
fn [<wasm_ $name>]() {
$name()
}
}
};
}
// Two consecutive counter reads on the same thread must be increasing.
#[test]
fn counter_increments() {
let start = get_count();
let next = get_count();
assert!(next > start);
}
wasm_test!(counter_increments);
// A freshly generated ID must pass the format validator.
#[test]
fn cuid_generation() {
assert!(is_cuid(cuid()))
}
wasm_test!(cuid_generation);
// Generates 10,000 IDs sequentially and checks that all are distinct.
#[wasm_bindgen_test::wasm_bindgen_test]
fn wasm_collisions() {
let count = 10_000;
let cuids = (0..count).fold(HashSet::with_capacity(count), |mut acc, _| {
acc.insert(cuid());
acc
});
assert_eq!(count, cuids.len());
}
// Ignored by default. Generates about ten million IDs split evenly across one
// thread per CPU core and checks that no two are equal.
#[cfg(not(target_family = "wasm"))] #[test]
#[ignore] fn collisions() {
let cores = num_cpus::get();
let per_core = 10_000_000 / cores;
// The intermediate `collect` is required: the iterator is lazy, so without
// it each thread would be spawned and joined one at a time instead of all
// threads running concurrently.
#[allow(clippy::needless_collect)]
let threads = (0..cores)
.map(|_| thread::spawn(move || (0..per_core).map(|_| create_id()).collect::<Vec<_>>()))
.collect::<Vec<_>>();
let res = threads
.into_iter()
.flat_map(|handle| handle.join().unwrap())
.collect::<Vec<_>>();
assert_eq!(res.iter().collect::<HashSet<_>>().len(), res.len())
}
// Ignored by default. Generates one million IDs, interprets everything after
// the leading letter as a base-36 number, buckets it modulo 20, and checks
// that every bucket is within 5% of the expected size.
#[test]
#[ignore] fn distribution() {
let count = 1_000_000;
let buckets = [0_u64; 20];
let bucket_count = bigint::BigUint::from(buckets.len());
let histogram = (0..count)
.map(|_| create_id())
// Skip the first character: it is the random leading letter, not hash output.
.map(|id| bigint::BigUint::parse_bytes(&id.as_bytes()[1..], 36).unwrap())
.map(|num| {
let bucket_number = &num % &bucket_count;
let digits = bucket_number.to_u32_digits();
// A value below 20 fits in at most one u32 digit; zero has no digits at all.
assert!(digits.len() < 2, "{num}: {bucket_number}: {digits:?}");
digits.first().copied().unwrap_or(0)
})
.fold(buckets, |mut histogram, bucket_num| {
histogram[bucket_num as usize] += 1;
histogram
});
let expected_bucket_size = count / histogram.len();
let tolerance = 0.05;
let max_bucket_size = (expected_bucket_size as f64 * (1.0 + tolerance)).round() as u64;
let min_bucket_size = (expected_bucket_size as f64 * (1.0 - tolerance)).round() as u64;
histogram
.into_iter()
.enumerate()
.for_each(|(idx, bucket_size)| {
assert!(bucket_size > min_bucket_size, "bucket {idx} too small");
assert!(bucket_size < max_bucket_size, "bucket {idx} too big");
})
}
}