cuid2/
lib.rs

1//! # Cuid2
2//!
3//! Secure, collision-resistant ids optimized for horizontal scaling and
4//! performance.
5//!
6//! This is a Rust implementation of the CUID2 algorithm, defined by its
7//! reference implementation [here](https://github.com/paralleldrive/cuid2).
8//!
9//! Please see that repository for a discussion of the benefits of CUIDs, as
10//! well as for the improvements in CUID2 over the original CUID algorithm
11//! (which is also implemented in Rust [here](https://docs.rs/cuid/latest/cuid/)).
12//!
13//! ## Usage
14//!
15//! The simplest usage is to use the `create_id()` function to create an ID:
16//!
17//! ```
18//! use cuid2;
19//!
20//! let id = cuid2::create_id();
21//!
22//! assert_eq!(24, id.len());
23//! ```
24//!
25//! A `cuid()` alias is provided to make this more of a drop-in replacement for
26//! the v1 cuid package:
27//!
28//! ```
29//! use cuid2::cuid;
30//!
31//! let id = cuid();
32//!
33//! assert_eq!(24, id.len());
34//! ```
35//!
36//! If you would like to customize aspects of CUID production, you can create
37//! a constructor with customized properties:
38//!
39//! ```
40//! use cuid2::CuidConstructor;
41//!
42//! let constructor = CuidConstructor::new().with_length(32);
43//!
44//! let id = constructor.create_id();
45//!
46//! assert_eq!(32, id.len());
47//! ```
48//!
49//! If installed with `cargo install`, this package also provides a `cuid2`
50//! binary, which generates a CUID on the command line. It can be used like:
51//!
52//! ```ignore,compile_fail
53//! > cuid2
54//! y3cfw1hafbtezzflns334sb2
55//! ```
56
57use std::{
58    cell::RefCell,
59    collections::hash_map::DefaultHasher,
60    hash::{Hash, Hasher},
61};
62
63// std::time::SystemTime panics on WASM, so use a different library there.
64#[cfg(not(target_family = "wasm"))]
65use std::time::{SystemTime, UNIX_EPOCH};
66#[cfg(target_family = "wasm")]
67use web_time::{SystemTime, UNIX_EPOCH};
68
69use cuid_util::ToBase36;
70use num::bigint;
71use rand::{seq::SliceRandom, thread_rng, Rng};
72use sha3::{Digest, Sha3_512};
73
74// =============================================================================
75// CONSTANTS
76// =============================================================================
77
// Only public to expose to binary
/// Length, in characters, of IDs generated with the default settings.
#[doc(hidden)]
pub const DEFAULT_LENGTH: u8 = 24;
/// Length requested for the per-thread fingerprint hash, and the longest
/// string that `is_cuid2` accepts.
const BIG_LENGTH: u8 = 32;
// valid characters to start an ID (the lowercase ASCII letters)
const STARTING_CHARS: &str = "abcdefghijklmnopqrstuvwxyz";
84
85// =============================================================================
86// THREAD LOCALS
87// =============================================================================
// Each thread generating CUIDs gets its own:
// - 64-bit counter, randomly initialized to a value in the range 0..476_782_367
// - fingerprint, a hash derived from blocks of random data, the process ID
//   (replaced by another random block on WASM), and the thread ID
92
93fn fingerprint() -> String {
94    hash(
95        [
96            thread_rng().gen::<u128>().to_be_bytes(),
97            thread_rng().gen::<u128>().to_be_bytes(),
98            #[cfg(not(target_family = "wasm"))]
99            u128::from(std::process::id()).to_be_bytes(),
100            // WASM has no concept of a PID, so just use another random block
101            #[cfg(target_family = "wasm")]
102            thread_rng().gen::<u128>().to_be_bytes(),
103            u128::from(get_thread_id()).to_be_bytes(),
104        ],
105        BIG_LENGTH.into(),
106    )
107}
108
109thread_local! {
110    /// Value used to initialize the counter. After the counter hits u64::MAX, it
111    /// will roll back to this value.
112    // Updated 2023-08-08 to match updated reference implementation, which notes:
113    // > ~22k hosts before 50% chance of initial counter collision
114    // > with a remaining counter range of 9.0e+15 in JavaScript.
115    static COUNTER_INIT: u64 = thread_rng().gen_range(0..476_782_367);
116
117    /// Use an individual counter per thread, starting at a randomly initialized value.
118    ///
119    /// Range of randomly initialized values taken from reference implementation.
120    static COUNTER: RefCell<u64> = COUNTER_INIT.with(|val| RefCell::new(*val));
121
122    /// Fingerprint! The original implementation is a hash of:
123    /// - stringified keys of the global object
124    /// - added entropy
125    ///
126    /// For us, we'll use
127    /// - A few random numbers
128    /// - the process ID
129    /// - the thread ID (which also ensures our CUIDs will be different per thread)
130    ///
131    /// This is pretty non-language, non-system dependent, so it allows us to
132    /// compile to wasm and so on.
133    static FINGERPRINT: String = fingerprint();
134}
135
136// Hashing
137// =======
138
139/// Hash a value, including an additional salt of randomly generated data.
140//
141// Updated 2023-08-08 to match the updated JS implementation, which is:
142//
143// ```js
144// const hash = (input = "") => {
145//   // Drop the first character because it will bias the histogram
146//   // to the left.
147//   return bufToBigInt(sha3(input)).toString(36).slice(1);
148// };
149// ```
150//
151// We don't drop the first character, because it doesn't actually affect the
152// histogram (the comment in the reference implementation is incorrect).
153fn hash<S: AsRef<[u8]>, T: IntoIterator<Item = S>>(input: T, length: u16) -> String {
154    let mut hasher = Sha3_512::new();
155
156    for block in input {
157        hasher.update(block.as_ref());
158    }
159
160    // 512 bits (64 bytes) of data ([u8; 64])
161    let hash = hasher.finalize();
162
163    // We'll convert the bytes directly to a big, unsigned int and then use
164    // its builtin radix conversion.
165    //
166    // We don't use bigint for the rest of our base conversions, because it's
167    // significantly slower, but we use it here since we need to deal with the
168    // 512-bit integer from the hash function.
169    let mut res = bigint::BigUint::from_bytes_be(&hash).to_str_radix(36);
170
171    // Note that truncate panics if the length does not fall on a char boundary,
172    // but we don't need to worry about that since all the chars will be ASCII.
173    res.truncate(length.into());
174
175    res
176}
177
178// Other Utility Functions
179// =======================
180
181/// Return whether a string is a legitimate CUID2
182/// ```rust
183/// use cuid2;
184/// let id = cuid2::create_id();
185/// let empty_id = "";
186/// let too_small = "a";
187/// let too_big = "a1l23j1l2k3j12o8312j3k12j3lj12k3j1lk2j312j3lkj12l3g1kj2h312312lk3j1l2j3lk12j3lkjlj1lk23jl131l2k3jl12j3lk1j2lk3j12lk3h12k3hhl1j2j3";
188/// let non_ascii_alphanumeric = "a#";
189/// let non_first_letter = "1aaa";
190/// let with_underscore = "aaa_1aaa";
191/// assert!(cuid2::is_cuid2(id));
192/// assert!(!cuid2::is_cuid2(empty_id));
193/// assert!(!cuid2::is_cuid2(too_small));
194/// assert!(!cuid2::is_cuid2(too_big));
195/// assert!(!cuid2::is_cuid2(non_ascii_alphanumeric));
196/// assert!(!cuid2::is_cuid2(non_first_letter));
197/// assert!(!cuid2::is_cuid2(with_underscore));
198/// ```
199#[inline]
200pub fn is_cuid2<S: AsRef<str>>(to_check: S) -> bool {
201    const MAX_LENGTH: usize = BIG_LENGTH as usize;
202    is_cuid2_inner::<S, MAX_LENGTH>(to_check)
203}
204
205fn is_cuid2_inner<S: AsRef<str>, const MAX_LENGTH: usize>(to_check: S) -> bool {
206    let to_check = to_check.as_ref().as_bytes();
207
208    if (2..=MAX_LENGTH).contains(&to_check.len()) {
209        if let [first, tail @ ..] = to_check {
210            return STARTING_CHARS.as_bytes().contains(first)
211                && tail.iter().all(|x| matches!(x, b'0'..=b'9' | b'a'..=b'z'));
212        }
213    }
214
215    false
216}
217
218/// Return whether a string is a legitimate CUID.
219///
220/// This is an alias of [is_cuid2]
221#[inline]
222pub fn is_cuid<S: AsRef<str>>(to_check: S) -> bool {
223    is_cuid2(to_check)
224}
225
226/// Creates a random string of the specified length.
227fn create_entropy(length: u16) -> String {
228    let mut rng = thread_rng();
229    let length: usize = length.into();
230
231    // Allocate a string with the appropriate capacity to avoid reallocation.
232    //
233    // The string is generated and then pushed to until its desired length is
234    // reached or exceeded. We therefore allocate enough for the length plus
235    // the maximum value it might be exceeded by. The values pushed to the
236    // string are random numbers from 0 to 36, converted to base 36.
237    // Therefore, the maximum overfill is 36 in base 36, i.e. 10, which is 2
238    // chars
239    let mut result = String::with_capacity(length + 2);
240
241    while result.len() < length {
242        // Matches reference implementation logic as of 2023-08-08, which is:
243        // ```js
244        // entropy = entropy + Math.floor(random() * 36).toString(36);
245        // ```
246        let random_val = rng.gen_range(0u128..36u128);
247        result.push_str(&random_val.to_base_36());
248    }
249
250    result
251}
252
253/// Retrieves the current timestmap and converts to Base36.
254fn get_timestamp() -> String {
255    SystemTime::now()
256        .duration_since(UNIX_EPOCH)
257        // Use timestamp as milliseconds to match JS implementation
258        .map(|time| time.as_millis().to_base_36())
259        // Panic safety: `.duration_since()` fails if the end time is not
260        // later than the start time, so this will only fail if the system
261        // time is before 1970-01-01. It is impossible on Unix systems to set
262        // a time before then, since the entire system uses a 32 or 64 bit
263        // unsigned integer for time, where zero is midnight 1970-01-01.
264        //
265        // If you are on a system that for some reason both can be and needs to
266        // be set >50 years in the past AND this library not working is a
267        // problem for you, please feel free to reach out.
268        .expect(
269            "Failed to calculate system timestamp! Current system time may be \
270                 set to before the Unix epoch, or time may otherwise be broken. \
271                 Cannot continue",
272        )
273}
274
275/// Retrieves and increments the counter value.
276fn get_count() -> u64 {
277    COUNTER.with(|cell| {
278        cell.replace_with(|counter| {
279            counter
280                .checked_add(1)
281                // if we hit u64::MAX, roll back to the original thread-local
282                // initialization value
283                .unwrap_or_else(|| COUNTER_INIT.with(|x| *x))
284        })
285    })
286}
287
288/// Retrieves the thread-local fingerprint.
289fn get_fingerprint() -> String {
290    FINGERPRINT.with(|x| x.clone())
291}
292
/// Derives a 64-bit number from the current thread's ID.
fn get_thread_id() -> u64 {
    // ThreadId implements Hash, so feed it through a hasher and use the
    // resulting value in the fingerprint.
    let thread_id = std::thread::current().id();
    let mut state = DefaultHasher::new();
    thread_id.hash(&mut state);
    state.finish()
}
301
302// =============================================================================
303// CUID CONSTRUCTION
304// =============================================================================
305
306/// Provides customization of CUID generation.
307///
308/// ```
309/// use cuid2::CuidConstructor;
310///
311/// let mut constructor = CuidConstructor::new();
312/// assert_eq!(24, constructor.create_id().len());
313///
314/// constructor.set_length(16);
315/// assert_eq!(16, constructor.create_id().len());
316///
317/// assert_eq!(32, CuidConstructor::new().with_length(32).create_id().len());
318/// ```
319pub struct CuidConstructor {
320    length: u16,
321    counter: fn() -> u64,
322    fingerprinter: fn() -> String,
323}
324impl CuidConstructor {
325    /// Creates a new constructor with default settings.
326    pub const fn new() -> Self {
327        Self {
328            length: DEFAULT_LENGTH as u16,
329            counter: get_count,
330            fingerprinter: get_fingerprint,
331        }
332    }
333
334    /// Returns a new constructor that will generate CUIDs with the specified length.
335    ///
336    /// # Panics
337    ///
338    /// Panics if `length` is less than 2.
339    ///
340    pub const fn with_length(self, length: u16) -> Self {
341        if length < 2 {
342            panic!("CUID length must be at least 2")
343        }
344        Self { length, ..self }
345    }
346
347    /// Returns a new constructor with the specified counter function.
348    pub const fn with_counter(self, counter: fn() -> u64) -> Self {
349        Self { counter, ..self }
350    }
351
352    /// Returns a new constructor with the specified fingerprinter function.
353    pub const fn with_fingerprinter(self, fingerprinter: fn() -> String) -> Self {
354        Self {
355            fingerprinter,
356            ..self
357        }
358    }
359
360    /// Sets the length for CUIDs generated by this constrctor.
361    ///
362    /// # Panics
363    ///
364    /// Panics if `length` is less than 2.
365    ///
366    pub fn set_length(&mut self, length: u16) {
367        if length < 2 {
368            panic!("CUID length must be at least 2")
369        }
370        self.length = length;
371    }
372
373    /// Sets the counter function for this constructor.
374    pub fn set_counter(&mut self, counter: fn() -> u64) {
375        self.counter = counter;
376    }
377
378    /// Sets the fingerperinter function for this constructor.
379    pub fn set_fingerprinter(&mut self, fingerprinter: fn() -> String) {
380        self.fingerprinter = fingerprinter;
381    }
382
383    /// Creates a new CUID.
384    #[inline]
385    pub fn create_id(&self) -> String {
386        let time = get_timestamp();
387
388        let entropy = create_entropy(self.length);
389
390        let count = (self.counter)().to_base_36();
391
392        let fingerprint = (self.fingerprinter)();
393
394        // Construct the main part of the ID body by hashing the various inputs
395        let id_body = hash(
396            [
397                time.as_bytes(),
398                entropy.as_bytes(),
399                count.as_bytes(),
400                fingerprint.as_bytes(),
401            ],
402            // The hash should be the desired total length minus 1 character
403            // for the starting char.
404            self.length - 1,
405        );
406
407        // TODO check if index access makes a perf difference here
408        let first_letter = (*STARTING_CHARS
409            .as_bytes()
410            // Panic safety: choose() only returns None if the slice is empty,
411            // and STARTING_CHARS is a statically defined non-empty slice.
412            .choose(&mut thread_rng())
413            .expect("STARTING_CHARS cannot be empty")) as char;
414
415        // Return only the requested length
416        format!("{first_letter}{id_body}")
417    }
418}
419impl Default for CuidConstructor {
420    fn default() -> Self {
421        Self::new()
422    }
423}
424
/// Use a static constructor for create_id() so that we don't need to pay the
/// (minimal, probably trivial) cost of constructor creation when called via
/// `create_id()`.
static DEFAULT_CONSTRUCTOR: CuidConstructor = CuidConstructor::new();

/// Length, in characters, of the IDs produced by `slug()`, and the longest
/// string that `is_slug()` accepts.
const SLUG_LENGTH: u16 = 10;

/// Static constructor used by `slug()`: the default settings with the length
/// set to `SLUG_LENGTH`.
static SLUG_CONSTRUCTOR: CuidConstructor = CuidConstructor::new().with_length(SLUG_LENGTH);
433
434/// Creates a new CUID.
435#[inline]
436pub fn create_id() -> String {
437    DEFAULT_CONSTRUCTOR.create_id()
438}
439
440/// Creates a new CUID.
441///
442/// Alias for `created_id()`, which is the interface defined in the reference
443/// implementation. The `cuid()` interface allows easier drop-in replacement
444/// for crates using the v1 `cuid` crate.
445#[inline]
446pub fn cuid() -> String {
447    create_id()
448}
449
450/// Creates a new CUID slug, which is just a CUID with a length of 10 characters.
451#[inline]
452pub fn slug() -> String {
453    SLUG_CONSTRUCTOR.create_id()
454}
455
456/// Return whether a string looks like it could be a legitimate CUID slug.
457#[inline]
458pub fn is_slug<S: AsRef<str>>(to_check: S) -> bool {
459    const MAX_LENGTH: usize = SLUG_LENGTH as usize;
460    is_cuid2_inner::<S, MAX_LENGTH>(to_check)
461}
462
#[cfg(test)]
mod test {
    use std::{collections::HashSet, thread};

    use super::*;

    /// Run an already-defined test in WASM as well.
    ///
    /// Expands to a `wasm_<name>` function, marked as a wasm-bindgen test,
    /// that simply calls the named test.
    macro_rules! wasm_test {
        ($name:ident) => {
            paste::paste! {
                #[wasm_bindgen_test::wasm_bindgen_test]
                fn [<wasm_ $name>]() {
                    $name()
                }
            }
        };
    }

    /// Two consecutive reads of the counter must yield increasing values.
    #[test]
    fn counter_increments() {
        let start = get_count();
        let next = get_count();

        // The counter is thread-local, so only this thread advances it.
        assert!(next > start);
    }
    wasm_test!(counter_increments);

    /// A freshly generated ID must pass its own validity check.
    #[test]
    fn cuid_generation() {
        assert!(is_cuid(cuid()))
    }
    wasm_test!(cuid_generation);

    // lesser version of the collisions test for WASM
    #[wasm_bindgen_test::wasm_bindgen_test]
    fn wasm_collisions() {
        let count = 10_000;
        let cuids = (0..count).fold(HashSet::with_capacity(count), |mut acc, _| {
            acc.insert(cuid());
            acc
        });
        assert_eq!(count, cuids.len());
    }

    #[cfg(not(target_family = "wasm"))] // uses num_cpus, which we can't compile on wasm
    #[test]
    #[ignore] // slow: run explicitly when desired
    fn collisions() {
        // generate ~10e6 IDs across all available cores
        let cores = num_cpus::get();
        let per_core = 10_000_000 / cores;

        // collect to force spawning the threads instead of just holding them lazily
        #[allow(clippy::needless_collect)]
        let threads = (0..cores)
            .map(|_| thread::spawn(move || (0..per_core).map(|_| create_id()).collect::<Vec<_>>()))
            .collect::<Vec<_>>();

        let res = threads
            .into_iter()
            .flat_map(|handle| handle.join().unwrap())
            .collect::<Vec<_>>();

        // All IDs are unique
        assert_eq!(res.iter().collect::<HashSet<_>>().len(), res.len())
    }

    /// Asserts that CUIDs are uniformly distributed, ignoring the first
    /// character.
    ///
    /// See https://github.com/paralleldrive/cuid2/blob/b5665387fdf7f947e135f030a545df22c5010a7d/src/test-utils.js
    /// and https://github.com/paralleldrive/cuid2/blob/b5665387fdf7f947e135f030a545df22c5010a7d/src/histogram.js
    #[test]
    #[ignore] // slow: run explicitly when desired
    fn distribution() {
        let count = 1_000_000;

        let buckets = [0_u64; 20];
        let bucket_count = bigint::BigUint::from(buckets.len());

        let histogram = (0..count)
            .map(|_| create_id())
            // parse the ID (minus starting char) as a base36 number
            .map(|id| bigint::BigUint::parse_bytes(id[1..].as_bytes(), 36).unwrap())
            // Determine its bucket number.
            // We know the bucket number will be <20, so we .to_u32_digits()
            // and grab what should be the only item.
            .map(|num| {
                let bucket_number = &num % &bucket_count;
                let digits = bucket_number.to_u32_digits();
                assert!(digits.len() < 2, "{num}: {bucket_number}: {digits:?}");
                digits.first().copied().unwrap_or(0)
            })
            // create the histogram. For each bucket number, increment the count
            .fold(buckets, |mut histogram, bucket_num| {
                histogram[bucket_num as usize] += 1;
                histogram
            });

        // Every bucket must land within 5% of a perfectly even split.
        let expected_bucket_size = count / histogram.len();
        let tolerance = 0.05;
        let max_bucket_size = (expected_bucket_size as f64 * (1.0 + tolerance)).round() as u64;
        let min_bucket_size = (expected_bucket_size as f64 * (1.0 - tolerance)).round() as u64;

        histogram
            .into_iter()
            .enumerate()
            .for_each(|(idx, bucket_size)| {
                assert!(bucket_size > min_bucket_size, "bucket {idx} too small");
                assert!(bucket_size < max_bucket_size, "bucket {idx} too big");
            })
    }
}