reifydb_type/util/
base64.rs1use std::{error, fmt};
/// Standard base64 alphabet (RFC 4648 §4): `A-Z`, `a-z`, `0-9`, then `+` and `/`.
///
/// Typed as a 64-element array so the compiler guarantees every 6-bit value has a symbol.
const BASE64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL- and filename-safe base64 alphabet (RFC 4648 §5): as the standard one, but with `-`
/// and `_` in place of `+` and `/`.
const BASE64_URL_CHARS: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
7
/// A base64 encoder/decoder: a 64-symbol alphabet plus a padding policy.
///
/// The fields are private; use one of the predefined configurations such as
/// `Engine::STANDARD`, `Engine::STANDARD_NO_PAD` or `Engine::URL_SAFE_NO_PAD`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Engine {
    /// Symbol table: the symbol for the 6-bit value `v` is `alphabet[v]`.
    alphabet: &'static [u8],
    /// Whether encoding writes trailing `=` padding (and decoding tolerates it).
    use_padding: bool,
}
13
14impl Engine {
15 pub const STANDARD: Engine = Engine {
17 alphabet: BASE64_CHARS,
18 use_padding: true,
19 };
20
21 pub const STANDARD_NO_PAD: Engine = Engine {
23 alphabet: BASE64_CHARS,
24 use_padding: false,
25 };
26
27 pub const URL_SAFE_NO_PAD: Engine = Engine {
29 alphabet: BASE64_URL_CHARS,
30 use_padding: false,
31 };
32
33 pub fn encode(&self, input: &[u8]) -> String {
35 if input.is_empty() {
36 return String::new();
37 }
38
39 let mut result = String::new();
40 let mut i = 0;
41
42 while i < input.len() {
43 let b1 = input[i];
44 let b2 = if i + 1 < input.len() {
45 input[i + 1]
46 } else {
47 0
48 };
49 let b3 = if i + 2 < input.len() {
50 input[i + 2]
51 } else {
52 0
53 };
54
55 let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
56
57 result.push(self.alphabet[(n >> 18) & 63] as char);
58 result.push(self.alphabet[(n >> 12) & 63] as char);
59
60 if i + 1 < input.len() {
61 result.push(self.alphabet[(n >> 6) & 63] as char);
62 if i + 2 < input.len() {
63 result.push(self.alphabet[n & 63] as char);
64 } else if self.use_padding {
65 result.push('=');
66 }
67 } else if self.use_padding {
68 result.push('=');
69 result.push('=');
70 }
71
72 i += 3;
73 }
74
75 result
76 }
77
78 pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
80 if !self.use_padding && input.contains('=') {
82 return Err(DecodeError::UnexpectedPadding);
83 }
84
85 if self.use_padding && input.contains('=') {
87 let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
89 let padding_count = input.len() - padding_start;
90
91 if padding_count > 2 {
94 return Err(DecodeError::InvalidPadding);
95 }
96
97 if padding_start > 0 && input[..padding_start].contains('=') {
99 return Err(DecodeError::InvalidPadding);
100 }
101
102 if !input.len().is_multiple_of(4) {
104 return Err(DecodeError::InvalidPadding);
105 }
106
107 let non_padding_in_last_quantum = 4 - padding_count;
113 if non_padding_in_last_quantum < 2 {
114 return Err(DecodeError::InvalidPadding);
115 }
116 }
117
118 let input = input.trim_end_matches('=');
119 if input.is_empty() {
120 return Ok(Vec::new());
121 }
122
123 let mut result = Vec::new();
124 let mut accumulator = 0u32;
125 let mut bits_collected = 0;
126
127 for ch in input.chars() {
128 let value = self.char_to_value(ch)?;
129 accumulator = (accumulator << 6) | (value as u32);
130 bits_collected += 6;
131
132 if bits_collected >= 8 {
133 bits_collected -= 8;
134 result.push((accumulator >> bits_collected) as u8);
135 accumulator &= (1 << bits_collected) - 1;
136 }
137 }
138
139 Ok(result)
140 }
141
142 fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
143 let byte = ch as u8;
144 self.alphabet
145 .iter()
146 .position(|&b| b == byte)
147 .map(|pos| pos as u8)
148 .ok_or(DecodeError::InvalidCharacter(ch))
149 }
150}
151
/// Reasons `Engine::decode` can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// A character outside the engine's alphabet; the offending character is carried along.
    InvalidCharacter(char),
    /// An `=` was found, but the engine is configured without padding.
    UnexpectedPadding,
    /// Malformed padding, e.g. `=` before the end of the input, more than two trailing `=`,
    /// or a padded input whose length is not a multiple of 4.
    InvalidPadding,
}
158
159impl fmt::Display for DecodeError {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 match self {
162 DecodeError::InvalidCharacter(ch) => {
163 write!(f, "Invalid base64 character: '{}'", ch)
164 }
165 DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
166 DecodeError::InvalidPadding => {
167 write!(f, "Invalid base64 padding")
168 }
169 }
170 }
171}
172
173impl error::Error for DecodeError {}
174
175pub mod engine {
177 pub mod general_purpose {
178 use crate::util::base64::Engine;
179
180 pub const STANDARD: Engine = Engine::STANDARD;
181 pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
182 pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
183 }
184}
185
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_encode_standard() {
        assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
        assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
        assert_eq!(Engine::STANDARD.encode(b""), "");
    }

    #[test]
    fn test_encode_no_pad() {
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
    }

    #[test]
    fn test_decode_standard() {
        assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
        assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
        assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
    }

    #[test]
    fn test_roundtrip() {
        let data = b"Hello, World! \x00\x01\x02\xFF";
        let encoded = Engine::STANDARD.encode(data);
        let decoded = Engine::STANDARD.decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_invalid_padding() {
        // More than two padding characters.
        assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
        assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

        // Padding in the middle of the input.
        assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

        // Data after padding.
        assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

        // Padded input whose length is not a multiple of 4.
        assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

        // Well-formed inputs, padded and unpadded.
        assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok());
        assert!(Engine::STANDARD.decode("SGVsbA==").is_ok());
        assert!(Engine::STANDARD.decode("SGVs").is_ok());
    }

    #[test]
    fn test_url_safe_alphabet() {
        // 0xFB 0xFF produces symbols 62 and 63, the only ones that differ between alphabets.
        assert_eq!(Engine::STANDARD.encode(&[0xfb, 0xff]), "+/8=");
        assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
        assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_8").unwrap(), vec![0xfb_u8, 0xff_u8]);
        assert!(matches!(
            Engine::URL_SAFE_NO_PAD.decode("+/8"),
            Err(DecodeError::InvalidCharacter('+'))
        ));
    }

    #[test]
    fn test_unexpected_padding() {
        assert!(matches!(
            Engine::STANDARD_NO_PAD.decode("SGVsbG8="),
            Err(DecodeError::UnexpectedPadding)
        ));
        assert!(matches!(
            Engine::URL_SAFE_NO_PAD.decode("-_8="),
            Err(DecodeError::UnexpectedPadding)
        ));
    }

    #[test]
    fn test_general_purpose_aliases() {
        assert_eq!(engine::general_purpose::STANDARD.encode(b"Hello"), "SGVsbG8=");
        assert_eq!(engine::general_purpose::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
        assert_eq!(engine::general_purpose::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
    }
}