// reifydb_type/util/base64.rs

/// Standard base64 alphabet (RFC 4648 §4): the byte at index `i` is the
/// symbol that encodes the 6-bit value `i`.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                              abcdefghijklmnopqrstuvwxyz\
                              0123456789+/";
/// URL- and filename-safe alphabet (RFC 4648 §5): identical to the standard
/// alphabet except that `-` and `_` take the places of `+` and `/`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                  abcdefghijklmnopqrstuvwxyz\
                                  0123456789-_";

/// A base64 encoder/decoder configured with an alphabet and a padding policy.
///
/// Use one of the predefined configurations such as `Engine::STANDARD`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Engine {
    /// Symbol table: the byte at index `i` is the character for the 6-bit
    /// value `i` (the predefined alphabets hold 64 symbols).
    alphabet: &'static [u8],
    /// Whether `encode` appends `=` padding; when `false`, `decode` also
    /// rejects any `=` in its input.
    use_padding: bool,
}

15impl Engine {
16 pub const STANDARD: Engine = Engine {
18 alphabet: BASE64_CHARS,
19 use_padding: true,
20 };
21
22 pub const STANDARD_NO_PAD: Engine = Engine {
24 alphabet: BASE64_CHARS,
25 use_padding: false,
26 };
27
28 pub const URL_SAFE_NO_PAD: Engine = Engine {
30 alphabet: BASE64_URL_CHARS,
31 use_padding: false,
32 };
33
34 pub fn encode(&self, input: &[u8]) -> String {
36 if input.is_empty() {
37 return String::new();
38 }
39
40 let mut result = String::new();
41 let mut i = 0;
42
43 while i < input.len() {
44 let b1 = input[i];
45 let b2 = if i + 1 < input.len() {
46 input[i + 1]
47 } else {
48 0
49 };
50 let b3 = if i + 2 < input.len() {
51 input[i + 2]
52 } else {
53 0
54 };
55
56 let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
57
58 result.push(self.alphabet[(n >> 18) & 63] as char);
59 result.push(self.alphabet[(n >> 12) & 63] as char);
60
61 if i + 1 < input.len() {
62 result.push(self.alphabet[(n >> 6) & 63] as char);
63 if i + 2 < input.len() {
64 result.push(self.alphabet[n & 63] as char);
65 } else if self.use_padding {
66 result.push('=');
67 }
68 } else if self.use_padding {
69 result.push('=');
70 result.push('=');
71 }
72
73 i += 3;
74 }
75
76 result
77 }
78
79 pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
81 if !self.use_padding && input.contains('=') {
83 return Err(DecodeError::UnexpectedPadding);
84 }
85
86 if self.use_padding && input.contains('=') {
88 let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
90 let padding_count = input.len() - padding_start;
91
92 if padding_count > 2 {
95 return Err(DecodeError::InvalidPadding);
96 }
97
98 if padding_start > 0 && input[..padding_start].contains('=') {
100 return Err(DecodeError::InvalidPadding);
101 }
102
103 if input.len() % 4 != 0 {
105 return Err(DecodeError::InvalidPadding);
106 }
107
108 let non_padding_in_last_quantum = 4 - padding_count;
114 if non_padding_in_last_quantum < 2 {
115 return Err(DecodeError::InvalidPadding);
116 }
117 }
118
119 let input = input.trim_end_matches('=');
120 if input.is_empty() {
121 return Ok(Vec::new());
122 }
123
124 let mut result = Vec::new();
125 let mut accumulator = 0u32;
126 let mut bits_collected = 0;
127
128 for ch in input.chars() {
129 let value = self.char_to_value(ch)?;
130 accumulator = (accumulator << 6) | (value as u32);
131 bits_collected += 6;
132
133 if bits_collected >= 8 {
134 bits_collected -= 8;
135 result.push((accumulator >> bits_collected) as u8);
136 accumulator &= (1 << bits_collected) - 1;
137 }
138 }
139
140 Ok(result)
141 }
142
143 fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
144 let byte = ch as u8;
145 self.alphabet
146 .iter()
147 .position(|&b| b == byte)
148 .map(|pos| pos as u8)
149 .ok_or(DecodeError::InvalidCharacter(ch))
150 }
151}
152
/// Error returned by `Engine::decode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// A character outside the engine's alphabet was found; carries that
    /// character.
    InvalidCharacter(char),
    /// `=` appeared in input given to an engine that does not use padding.
    UnexpectedPadding,
    /// `=` padding was malformed: more than two trailing `=`, an `=` before
    /// the end of the input, or padded input whose length is not a multiple
    /// of four.
    InvalidPadding,
}

160impl std::fmt::Display for DecodeError {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 DecodeError::InvalidCharacter(ch) => {
164 write!(f, "Invalid base64 character: '{}'", ch)
165 }
166 DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
167 DecodeError::InvalidPadding => {
168 write!(f, "Invalid base64 padding")
169 }
170 }
171 }
172}
173
174impl std::error::Error for DecodeError {}
175
176pub mod engine {
178 pub mod general_purpose {
179 pub const STANDARD: super::super::Engine = super::super::Engine::STANDARD;
180 pub const STANDARD_NO_PAD: super::super::Engine = super::super::Engine::STANDARD_NO_PAD;
181 pub const URL_SAFE_NO_PAD: super::super::Engine = super::super::Engine::URL_SAFE_NO_PAD;
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_standard() {
        assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
        assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
        assert_eq!(Engine::STANDARD.encode(b""), "");
    }

    #[test]
    fn test_encode_no_pad() {
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
    }

    #[test]
    fn test_decode_standard() {
        assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
        assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
        assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
    }

    #[test]
    fn test_roundtrip() {
        let data = b"Hello, World! \x00\x01\x02\xFF";
        let encoded = Engine::STANDARD.encode(data);
        let decoded = Engine::STANDARD.decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_invalid_padding() {
        // Too many padding characters.
        assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
        assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

        // Padding in the middle of the input.
        assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

        // Data after padding.
        assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

        // Padded input whose length is not a multiple of 4.
        assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

        // Well-formed inputs.
        assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok());
        assert!(Engine::STANDARD.decode("SGVsbA==").is_ok());
        assert!(Engine::STANDARD.decode("SGVs").is_ok());
    }

    #[test]
    fn test_url_safe_alphabet() {
        // 0xfb 0xff exercises the two symbols where the alphabets differ.
        assert_eq!(Engine::STANDARD.encode(&[0xfb, 0xff]), "+/8=");
        assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
        assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_8").unwrap(), [0xfb, 0xff]);
        assert!(Engine::URL_SAFE_NO_PAD.decode("+/8").is_err());
    }

    #[test]
    fn test_no_pad_engines_reject_padding() {
        assert!(matches!(
            Engine::STANDARD_NO_PAD.decode("SGVsbG8="),
            Err(DecodeError::UnexpectedPadding)
        ));
        assert!(matches!(
            Engine::URL_SAFE_NO_PAD.decode("SGVsbG8="),
            Err(DecodeError::UnexpectedPadding)
        ));
    }

    #[test]
    fn test_roundtrip_all_engines_and_lengths() {
        for engine in [Engine::STANDARD, Engine::STANDARD_NO_PAD, Engine::URL_SAFE_NO_PAD] {
            for len in 0..=16u8 {
                let data: Vec<u8> =
                    (0..len).map(|i| i.wrapping_mul(37).wrapping_add(200)).collect();
                assert_eq!(engine.decode(&engine.encode(&data)).unwrap(), data);
            }
        }
    }
}