// reifydb_type/util/base64.rs

/// Standard base64 alphabet: the byte at index `i` is the character emitted for the 6-bit
/// value `i`. Values 62 and 63 map to `+` and `/`.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL-safe base64 alphabet: identical to the standard one except that values 62 and 63
/// map to `-` and `_`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
8
/// A base64 codec configuration: the alphabet to use and whether `=` padding is in play.
///
/// Instances are small (a slice reference and a flag), so the type is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct Engine {
    /// Lookup table whose index is the 6-bit value and whose element is the output byte.
    alphabet: &'static [u8],
    /// When `true`, `encode` appends `=` padding and `decode` validates any padding it finds;
    /// when `false`, `encode` emits no padding and `decode` rejects any `=`.
    use_padding: bool,
}
14
15impl Engine {
16 pub const STANDARD: Engine = Engine {
18 alphabet: BASE64_CHARS,
19 use_padding: true,
20 };
21
22 pub const STANDARD_NO_PAD: Engine = Engine {
24 alphabet: BASE64_CHARS,
25 use_padding: false,
26 };
27
28 pub const URL_SAFE_NO_PAD: Engine = Engine {
30 alphabet: BASE64_URL_CHARS,
31 use_padding: false,
32 };
33
34 pub fn encode(&self, input: &[u8]) -> String {
36 if input.is_empty() {
37 return String::new();
38 }
39
40 let mut result = String::new();
41 let mut i = 0;
42
43 while i < input.len() {
44 let b1 = input[i];
45 let b2 = if i + 1 < input.len() {
46 input[i + 1]
47 } else {
48 0
49 };
50 let b3 = if i + 2 < input.len() {
51 input[i + 2]
52 } else {
53 0
54 };
55
56 let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
57
58 result.push(self.alphabet[(n >> 18) & 63] as char);
59 result.push(self.alphabet[(n >> 12) & 63] as char);
60
61 if i + 1 < input.len() {
62 result.push(self.alphabet[(n >> 6) & 63] as char);
63 if i + 2 < input.len() {
64 result.push(self.alphabet[n & 63] as char);
65 } else if self.use_padding {
66 result.push('=');
67 }
68 } else if self.use_padding {
69 result.push('=');
70 result.push('=');
71 }
72
73 i += 3;
74 }
75
76 result
77 }
78
79 pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
81 if !self.use_padding && input.contains('=') {
83 return Err(DecodeError::UnexpectedPadding);
84 }
85
86 if self.use_padding && input.contains('=') {
88 let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
90 let padding_count = input.len() - padding_start;
91
92 if padding_count > 2 {
95 return Err(DecodeError::InvalidPadding);
96 }
97
98 if padding_start > 0 && input[..padding_start].contains('=') {
100 return Err(DecodeError::InvalidPadding);
101 }
102
103 if input.len() % 4 != 0 {
105 return Err(DecodeError::InvalidPadding);
106 }
107
108 let non_padding_in_last_quantum = 4 - padding_count;
114 if non_padding_in_last_quantum < 2 {
115 return Err(DecodeError::InvalidPadding);
116 }
117 }
118
119 let input = input.trim_end_matches('=');
120 if input.is_empty() {
121 return Ok(Vec::new());
122 }
123
124 let mut result = Vec::new();
125 let mut accumulator = 0u32;
126 let mut bits_collected = 0;
127
128 for ch in input.chars() {
129 let value = self.char_to_value(ch)?;
130 accumulator = (accumulator << 6) | (value as u32);
131 bits_collected += 6;
132
133 if bits_collected >= 8 {
134 bits_collected -= 8;
135 result.push((accumulator >> bits_collected) as u8);
136 accumulator &= (1 << bits_collected) - 1;
137 }
138 }
139
140 Ok(result)
141 }
142
143 fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
144 let byte = ch as u8;
145 self.alphabet
146 .iter()
147 .position(|&b| b == byte)
148 .map(|pos| pos as u8)
149 .ok_or(DecodeError::InvalidCharacter(ch))
150 }
151}
152
/// Reasons a base64 string can fail to decode.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// The input contains a character that is not part of the engine's alphabet.
    InvalidCharacter(char),
    /// The input contains `=` but the engine was configured without padding.
    UnexpectedPadding,
    /// The input contains `=` padding that is malformed for a padding engine.
    InvalidPadding,
}
159
160impl std::fmt::Display for DecodeError {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 DecodeError::InvalidCharacter(ch) => {
164 write!(f, "Invalid base64 character: '{}'", ch)
165 }
166 DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
167 DecodeError::InvalidPadding => {
168 write!(f, "Invalid base64 padding")
169 }
170 }
171 }
172}
173
// Marker impl so `DecodeError` can be used wherever a `std::error::Error` is expected;
// the body is empty, so every trait method keeps its default implementation.
impl std::error::Error for DecodeError {}
175
176pub mod engine {
178 pub mod general_purpose {
179 use crate::util::base64::Engine;
180
181 pub const STANDARD: Engine = Engine::STANDARD;
182 pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
183 pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
184 }
185}
186
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Padded encoding with the standard alphabet, including the empty input.
    #[test]
    fn test_encode_standard() {
        let cases: &[(&[u8], &str)] = &[
            (b"Hello", "SGVsbG8="),
            (b"Hello, World!", "SGVsbG8sIFdvcmxkIQ=="),
            (b"", ""),
        ];
        for &(raw, expected) in cases {
            assert_eq!(Engine::STANDARD.encode(raw), expected);
        }
    }

    /// The no-padding engine emits the same characters minus the trailing `=`.
    #[test]
    fn test_encode_no_pad() {
        let cases: &[(&[u8], &str)] = &[
            (b"Hello", "SGVsbG8"),
            (b"Hello, World!", "SGVsbG8sIFdvcmxkIQ"),
        ];
        for &(raw, expected) in cases {
            assert_eq!(Engine::STANDARD_NO_PAD.encode(raw), expected);
        }
    }

    /// The standard engine decodes padded, unpadded and empty input.
    #[test]
    fn test_decode_standard() {
        let cases: &[(&str, &[u8])] = &[
            ("SGVsbG8=", b"Hello"),
            ("SGVsbG8", b"Hello"),
            ("", b""),
        ];
        for &(text, expected) in cases {
            assert_eq!(Engine::STANDARD.decode(text).unwrap(), expected);
        }
    }

    /// Encoding then decoding returns the original bytes, including NUL and 0xFF.
    #[test]
    fn test_roundtrip() {
        let original = b"Hello, World! \x00\x01\x02\xFF";
        let restored = Engine::STANDARD
            .decode(&Engine::STANDARD.encode(original))
            .unwrap();
        assert_eq!(restored, original);
    }

    /// Malformed padding is rejected; well-formed or absent padding is accepted.
    #[test]
    fn test_invalid_padding() {
        let rejected = [
            // Too many padding characters.
            "SGVsbG8===",
            "SGVsbG8====",
            // Padding in the middle of the data.
            "SGVs=bG8=",
            // Data after the padding.
            "SGVsbG8=X",
            // Padded length is not a multiple of four.
            "SGVsbG8==",
        ];
        for text in rejected {
            assert!(Engine::STANDARD.decode(text).is_err());
        }

        // One pad, two pads, and no pad at all.
        let accepted = ["SGVsbG8=", "SGVsbA==", "SGVs"];
        for text in accepted {
            assert!(Engine::STANDARD.decode(text).is_ok());
        }
    }
}