use core::str;
use roaring::RoaringBitmap;
use std::{ops::Range, sync::OnceLock};
use unicode_normalization::UnicodeNormalization;
/// Byte-oriented wrapper around [`sasl_normalize_password`].
///
/// - Pure-ASCII input is returned unchanged (borrowed) without UTF-8 validation.
/// - Valid UTF-8 input is passed to [`sasl_normalize_password`]. The result is borrowed
///   when that function borrows and owned when it allocates.
/// - Input that is not valid UTF-8 cannot be prepared. It is returned unchanged (borrowed).
pub fn sasl_normalize_password_bytes(s: &[u8]) -> Cow<'_, [u8]> {
    // ASCII needs no preparation, so skip UTF-8 validation entirely.
    if s.is_ascii() {
        return Cow::Borrowed(s);
    }
    let Ok(text) = str::from_utf8(s) else {
        // Not UTF-8: pass the raw bytes through untouched.
        return Cow::Borrowed(s);
    };
    match sasl_normalize_password(text) {
        Cow::Borrowed(prepared) => Cow::Borrowed(prepared.as_bytes()),
        Cow::Owned(prepared) => Cow::Owned(prepared.into_bytes()),
    }
}
/// Prepares `s` for use as a SASL password following SASLprep (RFC 4013).
///
/// The stringprep steps (RFC 3454 §2) are applied in order:
///
/// 1. **Map**: code points in `maps_to_nothing` are removed, and code points in
///    `maps_to_space` are replaced with U+0020.
/// 2. **Normalize**: the mapped string is converted to NFKC.
/// 3. **Prohibit**: the result is rejected if any code point is outside the
///    `not_prohibited` table.
/// 4. **Bidi**: if the result contains a RandALCat character (`table_d1`), it must
///    contain no LCat character (`table_d2`). It must also start and end with a
///    RandALCat character.
///
/// Pure-ASCII input is returned unchanged without running these steps.
///
/// This function never fails. Whenever a step rejects the input (including when the
/// mapping step leaves an empty string), the original `s` is returned borrowed and
/// unmodified. A successful preparation of non-ASCII input is returned as `Cow::Owned`.
pub fn sasl_normalize_password(s: &str) -> Cow<'_, str> {
    // Fast path. This also skips the prohibition check, so ASCII control characters
    // pass through unchanged.
    if s.is_ascii() {
        return Cow::Borrowed(s);
    }

    // Step 1: map.
    // NOTE(review): RFC 3454 lists U+200B in both B.1 (mapped to nothing) and C.1.2
    // (non-ASCII space). Because the maps-to-nothing check runs first, a code point
    // present in both tables is dropped here rather than turned into a space. Other
    // SASLprep implementations may check the space table first; confirm this matches
    // what clients compute.
    let mut mapped = String::with_capacity(s.len());
    for c in s.chars() {
        let code_point = c as u32;
        if maps_to_nothing::is_char_included(code_point) {
            continue;
        }
        mapped.push(if maps_to_space::is_char_included(code_point) { ' ' } else { c });
    }
    if mapped.is_empty() {
        // An empty result after mapping is treated as a failure.
        return Cow::Borrowed(s);
    }

    // Step 2: NFKC normalization.
    let normalized: String = mapped.chars().nfkc().collect();
    // `next_back` is O(1) on `Chars`; `last()` would walk the whole string.
    let (Some(first), Some(last)) = (normalized.chars().next(), normalized.chars().next_back())
    else {
        return Cow::Borrowed(s);
    };

    // Step 3: prohibited output.
    if normalized.chars().any(is_saslprep_prohibited) {
        return Cow::Borrowed(s);
    }

    // Step 4: bidirectional text requirements (RFC 3454 §6).
    let is_rand_al_cat = |c: char| table_d1::is_char_included(c as u32);
    let is_l_cat = |c: char| table_d2::is_char_included(c as u32);
    if normalized.chars().any(is_rand_al_cat)
        && (normalized.chars().any(is_l_cat) || !is_rand_al_cat(first) || !is_rand_al_cat(last))
    {
        return Cow::Borrowed(s);
    }

    Cow::Owned(normalized)
}
/// Defines a module `$name` describing a table of half-open `u32` code-point ranges.
///
/// Each `($first, $last)` pair in the invocation becomes the range `$first..$last`,
/// with `$last` excluded. The generated module contains:
///
/// - `RANGES`: every range, in invocation order, as a fixed-size array;
/// - `is_char_included(c)`: `true` when `c` lies in at least one of the ranges.
///
/// The macro is `#[macro_export]`ed, hence the `__` prefix and `#[doc(hidden)]`. It is
/// re-exported crate-locally as `process_ranges`.
#[doc(hidden)]
#[macro_export]
macro_rules! __process_ranges {
    ($name:ident => $( ($first:literal, $last:literal) )*) => {
        pub mod $name {
            /// All ranges of this table (end-exclusive), in declaration order.
            #[allow(unused)]
            pub const RANGES: [std::ops::Range<u32>; [$($first),*].len()] =
                [$($first..$last),*];

            /// Returns `true` if code point `c` falls inside any range of this table.
            // Silences the lint that flags an exclusive range ending exactly one below the
            // start of another range. Such gaps are presumably intentional in the tables.
            #[allow(non_contiguous_range_endpoints)]
            #[allow(unused)]
            pub fn is_char_included(c: u32) -> bool {
                matches!(c, $($first..$last)|*)
            }
        }
    };
}
use std::borrow::Cow;
pub(crate) use __process_ranges as process_ranges;
use super::stringprep_table::{maps_to_nothing, maps_to_space, not_prohibited, table_d1, table_d2};
/// Builds a bitmap containing every code point covered by `ranges`.
///
/// Each range is half-open (`start..end`). Overlapping ranges are merged, since the
/// bitmap is a set.
fn create_bitmap_from_ranges(ranges: &[Range<u32>]) -> RoaringBitmap {
    ranges.iter().fold(RoaringBitmap::new(), |mut acc, span| {
        acc.insert_range(span.start..span.end);
        acc
    })
}
static NOT_PROHIBITED_BITMAP: std::sync::OnceLock<RoaringBitmap> = OnceLock::new();
fn get_not_prohibited_bitmap() -> &'static RoaringBitmap {
NOT_PROHIBITED_BITMAP.get_or_init(|| create_bitmap_from_ranges(¬_prohibited::RANGES))
}
/// Returns `true` if `c` may not appear in SASLprep output, i.e. it is absent from the
/// not-prohibited bitmap.
#[inline(always)]
fn is_saslprep_prohibited(c: char) -> bool {
    let code_point = u32::from(c);
    let allowed = get_not_prohibited_bitmap().contains(code_point);
    !allowed
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prohibited() {
        // U+0000 is an ASCII control character (RFC 3454 C.2.1).
        assert!(is_saslprep_prohibited('\0'));
        // U+100000 is a private-use code point (RFC 3454 C.3).
        assert!(is_saslprep_prohibited('\u{100000}'));
        assert!(!is_saslprep_prohibited('a'));
    }

    /// The examples from RFC 4013 §3. Inputs the RFC rejects come back unmodified,
    /// because this implementation falls back to the original input instead of erroring.
    #[test]
    fn rfc4013_examples() {
        assert_eq!(sasl_normalize_password("I\u{00AD}X"), "IX");
        assert_eq!(sasl_normalize_password("user"), "user");
        assert_eq!(sasl_normalize_password("USER"), "USER");
        assert_eq!(sasl_normalize_password("\u{00AA}"), "a");
        assert_eq!(sasl_normalize_password("\u{2168}"), "IX");
        assert_eq!(sasl_normalize_password("\u{0007}"), "\u{0007}");
        assert_eq!(sasl_normalize_password("\u{0627}\u{0031}"), "\u{0627}\u{0031}");
        assert!(matches!(sasl_normalize_password("user"), Cow::Borrowed(_)));
    }

    #[test]
    fn bytes_wrapper() {
        assert_eq!(&*sasl_normalize_password_bytes(b"user"), &b"user"[..]);
        assert_eq!(&*sasl_normalize_password_bytes("I\u{00AD}X".as_bytes()), &b"IX"[..]);
        // Invalid UTF-8 is passed through untouched.
        assert_eq!(&*sasl_normalize_password_bytes(b"\xff\xfe"), &b"\xff\xfe"[..]);
    }

    #[test]
    fn generate_roaring_bitmap() {
        let bitmap = create_bitmap_from_ranges(&not_prohibited::RANGES);
        for range in not_prohibited::RANGES.iter().filter(|r| !r.is_empty()) {
            assert!(bitmap.contains(range.start));
            assert!(bitmap.contains(range.end - 1));
        }
        println!("Bitmap cardinality: {}", bitmap.len());
        println!("Bitmap size in bytes: {}", bitmap.serialized_size());
    }
}