reifydb_type/util/
base64.rs1use std::{error, fmt};
/// Standard base64 alphabet: values 62 and 63 are encoded as `+` and `/`.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL-safe base64 alphabet: values 62 and 63 are encoded as `-` and `_`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

/// A base64 encoder/decoder defined by a 64-symbol alphabet and a padding policy.
///
/// Use one of the predefined configurations: [`Engine::STANDARD`],
/// [`Engine::STANDARD_NO_PAD`] or [`Engine::URL_SAFE_NO_PAD`].
#[derive(Debug, Clone, Copy)]
pub struct Engine {
    /// The 64 output symbols, indexed by 6-bit value.
    alphabet: &'static [u8],
    /// Whether `encode` appends `=` and whether `decode` tolerates it.
    use_padding: bool,
}

impl Engine {
    /// Standard alphabet, output padded with `=` to a multiple of four characters.
    pub const STANDARD: Engine = Engine {
        alphabet: BASE64_CHARS,
        use_padding: true,
    };

    /// Standard alphabet, no padding emitted and none accepted when decoding.
    pub const STANDARD_NO_PAD: Engine = Engine {
        alphabet: BASE64_CHARS,
        use_padding: false,
    };

    /// URL-safe alphabet, no padding emitted and none accepted when decoding.
    pub const URL_SAFE_NO_PAD: Engine = Engine {
        alphabet: BASE64_URL_CHARS,
        use_padding: false,
    };

    /// Encodes `input` as base64 text using this engine's alphabet.
    ///
    /// Empty input yields an empty string. When padding is enabled the output
    /// length is always a multiple of four.
    pub fn encode(&self, input: &[u8]) -> String {
        // Three input bytes become four symbols; the padded length is also an
        // upper bound for the unpadded form.
        let mut result = String::with_capacity((input.len() + 2) / 3 * 4);

        for chunk in input.chunks(3) {
            // Missing bytes of a short final chunk are treated as zero bits.
            let b0 = chunk[0] as usize;
            let b1 = chunk.get(1).copied().unwrap_or(0) as usize;
            let b2 = chunk.get(2).copied().unwrap_or(0) as usize;
            let n = (b0 << 16) | (b1 << 8) | b2;

            result.push(self.alphabet[(n >> 18) & 63] as char);
            result.push(self.alphabet[(n >> 12) & 63] as char);

            match chunk.len() {
                3 => {
                    result.push(self.alphabet[(n >> 6) & 63] as char);
                    result.push(self.alphabet[n & 63] as char);
                }
                2 => {
                    result.push(self.alphabet[(n >> 6) & 63] as char);
                    if self.use_padding {
                        result.push('=');
                    }
                }
                _ => {
                    if self.use_padding {
                        result.push_str("==");
                    }
                }
            }
        }

        result
    }

    /// Decodes base64 text produced with this engine's alphabet.
    ///
    /// Engines with padding enabled also accept input whose trailing `=` has
    /// been omitted. Leftover bits that do not fill a whole byte (including a
    /// single dangling symbol) are discarded rather than rejected.
    ///
    /// # Errors
    ///
    /// * [`DecodeError::UnexpectedPadding`] - the input contains `=` but this
    ///   engine does not use padding.
    /// * [`DecodeError::InvalidPadding`] - more than two trailing `=`, an `=`
    ///   before the trailing run, or a padded input whose length is not a
    ///   multiple of four.
    /// * [`DecodeError::InvalidCharacter`] - a character outside the alphabet,
    ///   including any non-ASCII character.
    pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
        // `trim_end_matches` always cuts on a character boundary, so the split
        // is safe even when the input contains multi-byte characters.
        let data = input.trim_end_matches('=');
        let padding_count = input.len() - data.len();

        if input.contains('=') {
            if !self.use_padding {
                return Err(DecodeError::UnexpectedPadding);
            }
            // With at most two '=' and a total length divisible by four, the
            // final quantum necessarily holds at least two data symbols.
            if padding_count > 2 || data.contains('=') || input.len() % 4 != 0 {
                return Err(DecodeError::InvalidPadding);
            }
        }

        if data.is_empty() {
            return Ok(Vec::new());
        }

        // Four symbols carry three bytes; the extra two cover a partial tail.
        let mut result = Vec::with_capacity(data.len() / 4 * 3 + 2);
        let mut accumulator = 0u32;
        let mut bits_collected = 0u32;

        for ch in data.chars() {
            let value = self.char_to_value(ch)?;
            accumulator = (accumulator << 6) | u32::from(value);
            bits_collected += 6;

            if bits_collected >= 8 {
                bits_collected -= 8;
                // At most 12 bits are live here, so the shift leaves exactly 8.
                result.push((accumulator >> bits_collected) as u8);
                // Keep only the bits that have not been emitted yet.
                accumulator &= (1 << bits_collected) - 1;
            }
        }

        Ok(result)
    }

    /// Maps one symbol to its 6-bit value in this engine's alphabet.
    fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
        // A bare `ch as u8` keeps only the low byte, which would alias
        // characters such as U+0141 onto alphabet symbols; reject them first.
        if !ch.is_ascii() {
            return Err(DecodeError::InvalidCharacter(ch));
        }
        let byte = ch as u8;
        self.alphabet
            .iter()
            .position(|&b| b == byte)
            // The alphabet has 64 entries, so the index always fits in a u8.
            .map(|pos| pos as u8)
            .ok_or(DecodeError::InvalidCharacter(ch))
    }
}

/// Reasons [`Engine::decode`] can reject its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// The input contains a character that is not part of the engine's alphabet.
    InvalidCharacter(char),
    /// The input contains `=` but the engine is configured without padding
    /// (this applies to every no-pad engine, not only the URL-safe one).
    UnexpectedPadding,
    /// The `=` padding is misplaced, too long, or leaves a bad total length.
    InvalidPadding,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DecodeError::InvalidCharacter(ch) => {
                write!(f, "Invalid base64 character: '{}'", ch)
            }
            DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
            DecodeError::InvalidPadding => {
                write!(f, "Invalid base64 padding")
            }
        }
    }
}

impl error::Error for DecodeError {}
176
177pub mod engine {
179 pub mod general_purpose {
180 use crate::util::base64::Engine;
181
182 pub const STANDARD: Engine = Engine::STANDARD;
183 pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
184 pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
185 }
186}
187
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Padded encoding of known plaintexts, including the empty input.
    #[test]
    fn test_encode_standard() {
        let cases = [
            ("Hello", "SGVsbG8="),
            ("Hello, World!", "SGVsbG8sIFdvcmxkIQ=="),
            ("", ""),
        ];
        for &(plain, expected) in cases.iter() {
            assert_eq!(Engine::STANDARD.encode(plain.as_bytes()), expected);
        }
    }

    /// The no-pad engine emits the same symbols without trailing `=`.
    #[test]
    fn test_encode_no_pad() {
        let cases = [("Hello", "SGVsbG8"), ("Hello, World!", "SGVsbG8sIFdvcmxkIQ")];
        for &(plain, expected) in cases.iter() {
            assert_eq!(Engine::STANDARD_NO_PAD.encode(plain.as_bytes()), expected);
        }
    }

    /// The padded engine decodes both padded and unpadded input.
    #[test]
    fn test_decode_standard() {
        let cases = [("SGVsbG8=", "Hello"), ("SGVsbG8", "Hello"), ("", "")];
        for &(encoded, plain) in cases.iter() {
            let decoded = Engine::STANDARD.decode(encoded).unwrap();
            assert_eq!(decoded, plain.as_bytes());
        }
    }

    /// Encoding then decoding returns the original bytes, including non-text bytes.
    #[test]
    fn test_roundtrip() {
        let data: &[u8] = b"Hello, World! \x00\x01\x02\xFF";
        let decoded = Engine::STANDARD
            .decode(&Engine::STANDARD.encode(data))
            .unwrap();
        assert_eq!(decoded, data);
    }

    /// Malformed padding is rejected; well-formed or absent padding is accepted.
    #[test]
    fn test_invalid_padding() {
        let rejected = [
            "SGVsbG8===",
            "SGVsbG8====",
            "SGVs=bG8=",
            "SGVsbG8=X",
            "SGVsbG8==",
        ];
        for &input in rejected.iter() {
            assert!(
                Engine::STANDARD.decode(input).is_err(),
                "expected {:?} to be rejected",
                input
            );
        }

        let accepted = ["SGVsbG8=", "SGVsbA==", "SGVs"];
        for &input in accepted.iter() {
            assert!(
                Engine::STANDARD.decode(input).is_ok(),
                "expected {:?} to be accepted",
                input
            );
        }
    }
}