use core::fmt;
use crate::rng;
/// Alphabet used by [`generate`] and [`with_length`]: `_`, `-`, the ASCII
/// digits, the lowercase ASCII letters and the uppercase ASCII letters
/// (64 distinct bytes in total).
pub const DEFAULT_ALPHABET: &[u8] =
b"_-0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
/// Number of symbols in an identifier produced by [`generate`].
pub const DEFAULT_LENGTH: usize = 21;
pub fn generate() -> String {
custom(DEFAULT_LENGTH, DEFAULT_ALPHABET)
}
/// Generates an identifier of `length` symbols drawn from
/// [`DEFAULT_ALPHABET`].
///
/// A `length` of zero yields an empty string.
pub fn with_length(length: usize) -> String {
    let alphabet = DEFAULT_ALPHABET;
    custom(length, alphabet)
}
/// Generates an identifier of `length` symbols drawn from `alphabet`.
///
/// This is the lenient entry point: the alphabet is not validated. An empty
/// alphabet or a `length` of zero simply produces an empty string. Use
/// [`try_custom`] to have the alphabet checked first.
pub fn custom(length: usize, alphabet: &[u8]) -> String {
    match (length, alphabet) {
        // Nothing to emit, or nothing to pick from.
        (0, _) | (_, []) => String::new(),
        _ => custom_unchecked(length, alphabet),
    }
}
/// Generates an identifier of `length` symbols drawn from `alphabet`, after
/// checking the alphabet with [`validate_alphabet`].
///
/// # Errors
///
/// Returns whatever [`validate_alphabet`] reports for `alphabet`. The check
/// runs before `length` is looked at, so an invalid alphabet is rejected even
/// when `length` is zero.
pub fn try_custom(length: usize, alphabet: &[u8]) -> Result<String, AlphabetError> {
    validate_alphabet(alphabet).map(|()| {
        if length > 0 {
            custom_unchecked(length, alphabet)
        } else {
            String::new()
        }
    })
}
/// Checks that `alphabet` is non-empty and contains no byte twice.
///
/// # Errors
///
/// * [`AlphabetError::Empty`] if `alphabet` has no bytes.
/// * [`AlphabetError::Duplicate`] carrying the first byte that is seen for a
///   second time when scanning from the front.
pub fn validate_alphabet(alphabet: &[u8]) -> Result<(), AlphabetError> {
    if alphabet.is_empty() {
        return Err(AlphabetError::Empty);
    }

    // One flag per possible byte value. `replace` marks the byte as present
    // and hands back whether it had been marked before.
    let mut present = [false; 256];
    let repeated = alphabet
        .iter()
        .copied()
        .find(|&byte| std::mem::replace(&mut present[usize::from(byte)], true));

    match repeated {
        Some(byte) => Err(AlphabetError::Duplicate(byte)),
        None => Ok(()),
    }
}
/// Reasons an alphabet can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AlphabetError {
    /// The alphabet contains no bytes at all.
    Empty,
    /// The carried byte appears more than once in the alphabet.
    Duplicate(u8),
}

impl fmt::Display for AlphabetError {
    /// Writes a one-line, `nanoid:`-prefixed description of the error; a
    /// duplicated byte is shown as two lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AlphabetError::Duplicate(b) => {
                write!(f, "nanoid: duplicate byte 0x{b:02x} in alphabet")
            }
            AlphabetError::Empty => f.write_str("nanoid: alphabet must be non-empty"),
        }
    }
}

impl std::error::Error for AlphabetError {}
/// Builds an identifier of `length` symbols from `alphabet` without checking
/// either argument.
///
/// Callers are expected to pass a non-empty alphabet. Each alphabet byte is
/// converted with `as char`, so a byte of `0x80` or above becomes the code
/// point with the same number and takes two bytes in the returned string.
///
/// For alphabets with more than one symbol, 64-bit words are pulled from
/// `rng::next_u64` and consumed `mask_bits(len)` bits at a time. A candidate
/// index that falls outside the alphabet is thrown away instead of being
/// wrapped, so the mask itself never favours any symbol.
fn custom_unchecked(length: usize, alphabet: &[u8]) -> String {
    // With a single symbol there is nothing to choose; no random bits are
    // consumed.
    if let [only] = alphabet {
        let symbol = *only as char;
        return (0..length).map(|_| symbol).collect();
    }

    let width = mask_bits(alphabet.len());
    let low_bits: u64 = (1u64 << width) - 1;

    let mut id = String::with_capacity(length);
    let mut pool: u64 = 0;
    let mut available: u32 = 0;
    let mut remaining = length;

    while remaining > 0 {
        // Too few bits left for one more candidate: drop the leftovers and
        // start over with a fresh word.
        if available < width {
            pool = rng::next_u64();
            available = 64;
        }

        let candidate = (pool & low_bits) as usize;
        pool >>= width;
        available -= width;

        // `get` is `None` exactly when the candidate is past the end of the
        // alphabet, which is the rejection case.
        if let Some(&byte) = alphabet.get(candidate) {
            id.push(byte as char);
            remaining -= 1;
        }
    }

    id
}
/// Returns how many low-order bits are needed to hold every index in `0..n`,
/// i.e. the bit length of `n - 1`.
///
/// `n` must be at least 1; `n == 1` yields 0.
#[inline]
const fn mask_bits(n: usize) -> u32 {
    let highest_index = n - 1;
    let unused_high_bits = highest_index.leading_zeros();
    usize::BITS - unused_high_bits
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // ---- length and alphabet handling of the generators ----

    #[test]
    fn default_length() {
        let id = generate();
        assert_eq!(id.len(), 21);
    }

    #[test]
    fn with_length_correct() {
        let id = with_length(10);
        assert_eq!(id.len(), 10);
    }

    #[test]
    fn custom_alphabet_respected() {
        let id = custom(64, b"01");
        assert!(id.chars().all(|c| matches!(c, '0' | '1')));
        assert_eq!(id.len(), 64);
    }

    // ---- uniqueness smoke tests ----

    #[test]
    fn unique_ids() {
        let first = generate();
        let second = generate();
        assert_ne!(first, second);
    }

    #[test]
    fn many_default_unique() {
        let mut seen = HashSet::new();
        for _ in 0..10_000 {
            // `insert` returns false when the value was already present.
            let is_new = seen.insert(generate());
            assert!(is_new);
        }
    }

    // ---- degenerate inputs ----

    #[test]
    fn empty_alphabet_returns_empty() {
        let id = custom(10, &[]);
        assert_eq!(id, "");
    }

    #[test]
    fn zero_length_returns_empty() {
        let from_custom = custom(0, DEFAULT_ALPHABET);
        assert_eq!(from_custom, "");
        let from_with_length = with_length(0);
        assert_eq!(from_with_length, "");
    }

    #[test]
    fn single_char_alphabet() {
        assert_eq!(custom(8, b"x"), "xxxxxxxx");
    }

    // ---- validated entry point ----

    #[test]
    fn try_custom_rejects_empty() {
        let outcome = try_custom(8, b"");
        assert_eq!(outcome, Err(AlphabetError::Empty));
    }

    #[test]
    fn try_custom_rejects_duplicate() {
        let outcome = try_custom(8, b"abcda");
        assert_eq!(outcome.unwrap_err(), AlphabetError::Duplicate(b'a'));
    }

    #[test]
    fn try_custom_accepts_valid() {
        let id = try_custom(12, b"abcdef0123").unwrap();
        assert_eq!(id.len(), 12);
        for c in id.chars() {
            assert!("abcdef0123".contains(c));
        }
    }

    #[test]
    fn validate_alphabet_paths() {
        assert!(validate_alphabet(DEFAULT_ALPHABET).is_ok());
        assert_eq!(validate_alphabet(b""), Err(AlphabetError::Empty));
        assert_eq!(
            validate_alphabet(b"abca"),
            Err(AlphabetError::Duplicate(b'a'))
        );
    }

    // ---- sampling behaviour ----

    #[test]
    fn non_power_of_two_alphabet_unbiased() {
        // 17 symbols need a 5-bit mask, so 15 of the 32 candidate values are
        // rejected; each symbol should still land near 170_000 / 17 = 10_000.
        let alphabet: &[u8] = b"ABCDEFGHIJKLMNOPQ";
        let id = custom(170_000, alphabet);

        let mut counts = [0usize; 17];
        id.bytes()
            .map(|c| alphabet.iter().position(|&b| b == c).unwrap())
            .for_each(|slot| counts[slot] += 1);

        for (i, &n) in counts.iter().enumerate() {
            assert!(
                (8_800..=11_200).contains(&n),
                "alphabet[{i}] ({}) count {n} outside expected band",
                alphabet[i] as char
            );
        }
    }

    #[test]
    fn mask_bits_known_values() {
        for (size, expected) in [(2usize, 1u32), (8, 3), (64, 6), (65, 7), (256, 8)] {
            assert_eq!(mask_bits(size), expected);
        }
    }

    #[test]
    fn length_exact_across_alphabet_sizes() {
        // The 94 printable, non-space ASCII bytes.
        let printable: Vec<u8> = (b'!'..=b'~').collect();
        for n in [2usize, 7, 16, 33, 64, 65, 93, 94] {
            let alphabet = &printable[..n];
            let id = custom(50, alphabet);
            assert_eq!(id.chars().count(), 50, "size {n}");
            assert_eq!(id.len(), 50, "size {n}: ASCII alphabet so bytes==chars");
        }
    }

    #[test]
    fn non_ascii_alphabet_counts_chars_not_bytes() {
        // Every byte value, including the 128 that are not ASCII.
        let every_byte: Vec<u8> = (u8::MIN..=u8::MAX).collect();
        let id = custom(30, &every_byte);
        assert_eq!(id.chars().count(), 30);
    }

    #[test]
    fn default_alphabet_has_no_duplicates() {
        assert!(validate_alphabet(DEFAULT_ALPHABET).is_ok());
        assert_eq!(DEFAULT_ALPHABET.len(), 64);
    }
}