// gel_auth/scram/stringprep.rs
1use core::str;
2use roaring::RoaringBitmap;
3use std::{ops::Range, sync::OnceLock};
4use unicode_normalization::UnicodeNormalization;
5
6/// Normalize the password using the SASLprep algorithm from RFC4013.
7///
8/// # Examples
9///
10/// ```
11/// # use gel_auth::scram::stringprep::*;
12/// assert_eq!(sasl_normalize_password_bytes(b"password").as_ref(), b"password");
13/// assert_eq!(sasl_normalize_password_bytes("passw\u{00A0}rd".as_bytes()).as_ref(), b"passw rd");
14/// assert_eq!(sasl_normalize_password_bytes("pass\u{200B}word".as_bytes()).as_ref(), b"password");
15/// // This test case demonstrates that invalid UTF-8 sequences are returned unchanged.
16/// // The bytes 0xFF, 0xFE, 0xFD do not form a valid UTF-8 sequence, so the function
17/// // should return them as-is without attempting to normalize or modify them.
18/// assert_eq!(sasl_normalize_password_bytes(&[0xFF, 0xFE, 0xFD]).as_ref(), &[0xFF, 0xFE, 0xFD]);
19/// ```
20pub fn sasl_normalize_password_bytes(s: &[u8]) -> Cow<[u8]> {
21 if s.is_ascii() {
22 Cow::Borrowed(s)
23 } else if let Ok(s) = str::from_utf8(s) {
24 match sasl_normalize_password(s) {
25 Cow::Borrowed(s) => Cow::Borrowed(s.as_bytes()),
26 Cow::Owned(s) => Cow::Owned(s.into()),
27 }
28 } else {
29 Cow::Borrowed(s)
30 }
31}
32
33/// Normalize the password using the SASLprep from RFC4013.
34///
35/// # Examples
36///
37/// ```
38/// # use gel_auth::scram::stringprep::*;
39/// assert_eq!(sasl_normalize_password("password").as_ref(), "password");
40/// assert_eq!(sasl_normalize_password("passw\u{00A0}rd").as_ref(), "passw rd");
41/// assert_eq!(sasl_normalize_password("pass\u{200B}word").as_ref(), "password");
42/// assert_eq!(sasl_normalize_password("パスワード").as_ref(), "パスワード"); // precomposed Japanese
43/// assert_eq!(sasl_normalize_password("パスワード").as_ref(), "パスワード"); // half-width to full-width katakana
44/// assert_eq!(sasl_normalize_password("\u{0061}\u{0308}"), "\u{00E4}"); // a + combining diaeresis -> ä
45/// assert_eq!(sasl_normalize_password("\u{00E4}"), "\u{00E4}"); // precomposed ä
46/// assert_eq!(sasl_normalize_password("\u{0041}\u{0308}"), "\u{00C4}"); // A + combining diaeresis -> Ä
47/// assert_eq!(sasl_normalize_password("\u{00C4}"), "\u{00C4}"); // precomposed Ä
48/// assert_eq!(sasl_normalize_password("\u{0627}\u{0644}\u{0639}\u{0631}\u{0628}\u{064A}\u{0629}"), "\u{0627}\u{0644}\u{0639}\u{0631}\u{0628}\u{064A}\u{0629}"); // Arabic (RandALCat)
49/// ```
50pub fn sasl_normalize_password(s: &str) -> Cow<str> {
51 if s.is_ascii() {
52 return Cow::Borrowed(s);
53 }
54
55 let mut normalized = String::with_capacity(s.len());
56
57 // Step 1 of SASLPrep: Map. Per the algorithm, we map non-ascii space
58 // characters to ASCII spaces (\x20 or \u0020, but we will use ' ') and
59 // commonly mapped to nothing characters are removed
60 // Table C.1.2 -- non-ASCII spaces
61 // Table B.1 -- "Commonly mapped to nothing"
62 for c in s.chars() {
63 if !maps_to_nothing::is_char_included(c as u32) {
64 if maps_to_space::is_char_included(c as u32) {
65 normalized.push(' ');
66 } else {
67 normalized.push(c);
68 }
69 }
70 }
71
72 // If at this point the password is empty, PostgreSQL uses the original
73 // password
74 if normalized.is_empty() {
75 return Cow::Borrowed(s);
76 }
77
78 // Step 2 of SASLPrep: Normalize. Normalize the password using the
79 // Unicode normalization algorithm to NFKC form
80 let normalized = normalized.chars().nfkc().collect::<String>();
81
82 // If the password is not empty, PostgreSQL uses the original password
83 if normalized.is_empty() {
84 return Cow::Borrowed(s);
85 }
86
87 // Step 3 of SASLPrep: Prohibited characters. If PostgreSQL detects any
88 // of the prohibited characters in SASLPrep, it will use the original
89 // password
90 // We also include "unassigned code points" in the prohibited character
91 // category as PostgreSQL does the same
92 if normalized.chars().any(is_saslprep_prohibited) {
93 return Cow::Borrowed(s);
94 }
95
96 // Step 4 of SASLPrep: Bi-directional characters. PostgreSQL follows the
97 // rules for bi-directional characters laid on in RFC3454 Sec. 6 which
98 // are:
99 // 1. Characters in RFC 3454 Sec 5.8 are prohibited (C.8)
100 // 2. If a string contains a RandALCat character, it cannot contain any
101 // LCat character
102 // 3. If the string contains any RandALCat character, a RandALCat
103 // character must be the first and last character of the string
104 // RandALCat characters are found in table D.1, whereas LCat are in D.2.
105 // A RandALCat character is a character with unambiguously right-to-left
106 // directionality.
107 let first_char = normalized.chars().next().unwrap();
108 let last_char = normalized.chars().last().unwrap();
109
110 let contains_rand_al_cat = normalized
111 .chars()
112 .any(|c| table_d1::is_char_included(c as u32));
113 if contains_rand_al_cat {
114 let contains_l_cat = normalized
115 .chars()
116 .any(|c| table_d2::is_char_included(c as u32));
117 if !table_d1::is_char_included(first_char as u32)
118 || !table_d1::is_char_included(last_char as u32)
119 || contains_l_cat
120 {
121 return Cow::Borrowed(s);
122 }
123 }
124
125 // return the normalized password
126 Cow::Owned(normalized)
127}
128
/// Generates a module named `$name` that describes a set of Unicode code
/// points as half-open ranges.
///
/// Each `($first, $last)` pair becomes the range `$first..$last`, so `$last`
/// itself is *not* included. The generated module contains:
///
/// - `RANGES`: the ranges as a `[std::ops::Range<u32>; N]`, where `N` is the
///   number of pairs.
/// - `is_char_included(c)`: `true` when code point `c` lies in any of the
///   ranges. It is implemented as a single `match` over the range patterns.
///
/// Invocation shape: `__process_ranges!(name => (first, last) (first, last) ...)`.
/// Inside this crate it is also reachable as `process_ranges!` through the
/// `pub(crate) use` re-export in this file.
#[doc(hidden)]
#[macro_export]
macro_rules! __process_ranges {
    (
        $name:ident =>
        $( ($first:literal, $last:literal) )*
    ) => {
        pub mod $name {
            // The array length is the length of a throwaway array built from
            // every `$first` literal, i.e. the number of range pairs.
            #[allow(unused)]
            pub const RANGES: [std::ops::Range<u32>; [$($first),*].len()] = [
                $(
                    $first..$last,
                )*
            ];

            // Exclusive range patterns whose endpoints are one code point
            // apart trigger rustc's `non_contiguous_range_endpoints`
            // off-by-one lint. Such single-code-point gaps are presumably
            // intentional in the data tables, so the lint is silenced.
            // `unused` is allowed presumably because not every generated
            // table uses both items — confirm.
            #[allow(non_contiguous_range_endpoints)]
            #[allow(unused)]
            pub fn is_char_included(c: u32) -> bool {
                match c {
                    $(
                        $first..$last => true,
                    )*
                    _ => false,
                }
            }
        }
    };
}
157use std::borrow::Cow;
158
159pub(crate) use __process_ranges as process_ranges;
160
161use super::stringprep_table::{maps_to_nothing, maps_to_space, not_prohibited, table_d1, table_d2};
162
163fn create_bitmap_from_ranges(ranges: &[Range<u32>]) -> RoaringBitmap {
164 let mut bitmap = RoaringBitmap::new();
165 for range in ranges {
166 bitmap.insert_range(range.clone());
167 }
168 bitmap
169}
170
171static NOT_PROHIBITED_BITMAP: std::sync::OnceLock<RoaringBitmap> = OnceLock::new();
172
173fn get_not_prohibited_bitmap() -> &'static RoaringBitmap {
174 NOT_PROHIBITED_BITMAP.get_or_init(|| create_bitmap_from_ranges(¬_prohibited::RANGES))
175}
176
177#[inline(always)]
178fn is_saslprep_prohibited(c: char) -> bool {
179 !get_not_prohibited_bitmap().contains(c as u32)
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prohibited() {
        // NUL (a control character) and U+100000 (private use) must both be
        // rejected.
        assert!(is_saslprep_prohibited('\0'));
        assert!(is_saslprep_prohibited('\u{100000}'));
    }

    #[test]
    fn generate_roaring_bitmap() {
        let bitmap = create_bitmap_from_ranges(&not_prohibited::RANGES);

        // `is_saslprep_prohibited` consults the bitmap, while the
        // macro-generated `is_char_included` matches on the same ranges.
        // Probe every range edge, and the code points just outside it, to
        // check that both views agree.
        for range in &not_prohibited::RANGES {
            let probes = [
                range.start.saturating_sub(1),
                range.start,
                range.end.saturating_sub(1),
                range.end,
            ];
            for cp in probes {
                assert_eq!(
                    bitmap.contains(cp),
                    not_prohibited::is_char_included(cp),
                    "bitmap and range table disagree at U+{cp:04X}"
                );
            }
        }

        // To ship a prebuilt bitmap instead of building it at runtime, write
        // it out with `RoaringBitmap::serialize_into`.
        println!("Bitmap cardinality: {}", bitmap.len());
        println!("Bitmap size in bytes: {}", bitmap.serialized_size());
    }

    #[test]
    fn only_mapped_to_nothing_returns_original() {
        // U+00AD (soft hyphen) maps to nothing, leaving an empty password,
        // so the original is kept.
        assert_eq!(sasl_normalize_password("\u{00AD}"), "\u{00AD}");
    }

    #[test]
    fn halfwidth_katakana_is_widened() {
        assert_eq!(
            sasl_normalize_password("\u{FF8A}\u{FF9F}\u{FF7D}\u{FF9C}\u{FF70}\u{FF84}\u{FF9E}"),
            "\u{30D1}\u{30B9}\u{30EF}\u{30FC}\u{30C9}"
        );
    }

    #[test]
    fn rand_al_cat_with_l_cat_returns_original() {
        // Had the bidi check passed, U+00A0 would have been mapped to ' '.
        assert_eq!(
            sasl_normalize_password("\u{0627}\u{00A0}a"),
            "\u{0627}\u{00A0}a"
        );
    }

    #[test]
    fn rand_al_cat_must_end_string() {
        // '1' is neither RandALCat nor LCat, but it is the last character.
        assert_eq!(
            sasl_normalize_password("\u{0627}\u{00A0}1"),
            "\u{0627}\u{00A0}1"
        );
    }

    #[test]
    fn valid_rtl_string_is_mapped() {
        assert_eq!(
            sasl_normalize_password("\u{0627}\u{00A0}\u{0628}"),
            "\u{0627} \u{0628}"
        );
    }
}