reifydb_type/util/
base64.rs1use std::{error, fmt};
/// Standard base64 alphabet (RFC 4648, Table 1): the symbol at index `i`
/// encodes the 6-bit value `i`.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
    abcdefghijklmnopqrstuvwxyz\
    0123456789+/";

/// URL- and filename-safe alphabet (RFC 4648, Table 2): identical to
/// [`BASE64_CHARS`] except that `-` and `_` replace `+` and `/`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
    abcdefghijklmnopqrstuvwxyz\
    0123456789-_";
7
/// A base64 codec defined by an alphabet and a padding policy.
///
/// The public instances are the associated constants [`Engine::STANDARD`],
/// [`Engine::STANDARD_NO_PAD`] and [`Engine::URL_SAFE_NO_PAD`]. The fields are
/// a `'static` slice and a `bool`, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Engine {
    /// Symbol table; the symbol at index `i` encodes the 6-bit value `i`.
    alphabet: &'static [u8],
    /// Whether `encode` appends `=` padding and `decode` accepts `=`.
    use_padding: bool,
}
12
13impl Engine {
14 pub const STANDARD: Engine = Engine {
15 alphabet: BASE64_CHARS,
16 use_padding: true,
17 };
18
19 pub const STANDARD_NO_PAD: Engine = Engine {
20 alphabet: BASE64_CHARS,
21 use_padding: false,
22 };
23
24 pub const URL_SAFE_NO_PAD: Engine = Engine {
25 alphabet: BASE64_URL_CHARS,
26 use_padding: false,
27 };
28
29 pub fn encode(&self, input: &[u8]) -> String {
30 if input.is_empty() {
31 return String::new();
32 }
33
34 let mut result = String::new();
35 let mut i = 0;
36
37 while i < input.len() {
38 let b1 = input[i];
39 let b2 = if i + 1 < input.len() {
40 input[i + 1]
41 } else {
42 0
43 };
44 let b3 = if i + 2 < input.len() {
45 input[i + 2]
46 } else {
47 0
48 };
49
50 let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
51
52 result.push(self.alphabet[(n >> 18) & 63] as char);
53 result.push(self.alphabet[(n >> 12) & 63] as char);
54
55 if i + 1 < input.len() {
56 result.push(self.alphabet[(n >> 6) & 63] as char);
57 if i + 2 < input.len() {
58 result.push(self.alphabet[n & 63] as char);
59 } else if self.use_padding {
60 result.push('=');
61 }
62 } else if self.use_padding {
63 result.push('=');
64 result.push('=');
65 }
66
67 i += 3;
68 }
69
70 result
71 }
72
73 pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
74 if !self.use_padding && input.contains('=') {
75 return Err(DecodeError::UnexpectedPadding);
76 }
77
78 if self.use_padding && input.contains('=') {
79 let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
80 let padding_count = input.len() - padding_start;
81
82 if padding_count > 2 {
83 return Err(DecodeError::InvalidPadding);
84 }
85
86 if padding_start > 0 && input[..padding_start].contains('=') {
87 return Err(DecodeError::InvalidPadding);
88 }
89
90 if !input.len().is_multiple_of(4) {
91 return Err(DecodeError::InvalidPadding);
92 }
93
94 let non_padding_in_last_quantum = 4 - padding_count;
95 if non_padding_in_last_quantum < 2 {
96 return Err(DecodeError::InvalidPadding);
97 }
98 }
99
100 let input = input.trim_end_matches('=');
101 if input.is_empty() {
102 return Ok(Vec::new());
103 }
104
105 let mut result = Vec::new();
106 let mut accumulator = 0u32;
107 let mut bits_collected = 0;
108
109 for ch in input.chars() {
110 let value = self.char_to_value(ch)?;
111 accumulator = (accumulator << 6) | (value as u32);
112 bits_collected += 6;
113
114 if bits_collected >= 8 {
115 bits_collected -= 8;
116 result.push((accumulator >> bits_collected) as u8);
117 accumulator &= (1 << bits_collected) - 1;
118 }
119 }
120
121 Ok(result)
122 }
123
124 fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
125 let byte = ch as u8;
126 self.alphabet
127 .iter()
128 .position(|&b| b == byte)
129 .map(|pos| pos as u8)
130 .ok_or(DecodeError::InvalidCharacter(ch))
131 }
132}
133
/// Reasons `Engine::decode` can reject its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// A character that is not part of the engine's alphabet.
    InvalidCharacter(char),
    /// A `=` appeared in input given to an engine configured without padding.
    UnexpectedPadding,
    /// The input's padding or length is malformed, e.g. more than two `=`,
    /// `=` before the end of the input, or a padded input whose length is not
    /// a multiple of 4.
    InvalidPadding,
}
140
141impl fmt::Display for DecodeError {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 match self {
144 DecodeError::InvalidCharacter(ch) => {
145 write!(f, "Invalid base64 character: '{}'", ch)
146 }
147 DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
148 DecodeError::InvalidPadding => {
149 write!(f, "Invalid base64 padding")
150 }
151 }
152 }
153}
154
155impl error::Error for DecodeError {}
156
157pub mod engine {
158 pub mod general_purpose {
159 use crate::util::base64::Engine;
160
161 pub const STANDARD: Engine = Engine::STANDARD;
162 pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
163 pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
164 }
165}
166
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_encode_standard() {
        assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
        assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
        assert_eq!(Engine::STANDARD.encode(b""), "");
    }

    #[test]
    fn test_encode_no_pad() {
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
        assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
    }

    #[test]
    fn test_decode_standard() {
        assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
        // Padding is optional for the padded engine.
        assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
        assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
    }

    /// Test vectors from RFC 4648 §10.
    #[test]
    fn test_rfc4648_vectors() {
        let cases: [(&str, &str); 7] = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for &(raw, encoded) in cases.iter() {
            assert_eq!(Engine::STANDARD.encode(raw.as_bytes()), encoded);
            assert_eq!(Engine::STANDARD.decode(encoded).unwrap(), raw.as_bytes());

            let unpadded = encoded.trim_end_matches('=');
            assert_eq!(Engine::STANDARD_NO_PAD.encode(raw.as_bytes()), unpadded);
            assert_eq!(Engine::STANDARD_NO_PAD.decode(unpadded).unwrap(), raw.as_bytes());
        }
    }

    #[test]
    fn test_url_safe_alphabet() {
        // 0xfb 0xff 0xbf maps to the 6-bit values 62, 63, 62, 63, the two
        // symbols where the standard and URL-safe alphabets differ.
        let bytes = [0xfbu8, 0xff, 0xbf];
        assert_eq!(Engine::STANDARD.encode(&bytes), "+/+/");
        assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&bytes), "-_-_");
        assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_-_").unwrap(), bytes);
        assert_eq!(Engine::URL_SAFE_NO_PAD.encode(b"Hello"), "SGVsbG8");
    }

    #[test]
    fn test_roundtrip() {
        let data = b"Hello, World! \x00\x01\x02\xFF";
        let encoded = Engine::STANDARD.encode(data);
        let decoded = Engine::STANDARD.decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_roundtrip_all_engines_and_lengths() {
        let data: Vec<u8> = (0..=255u8).collect();
        let engines = [Engine::STANDARD, Engine::STANDARD_NO_PAD, Engine::URL_SAFE_NO_PAD];
        for engine in engines.iter() {
            for len in 0..=data.len() {
                let encoded = engine.encode(&data[..len]);
                assert_eq!(engine.decode(&encoded).unwrap(), &data[..len]);
            }
        }
    }

    #[test]
    fn test_invalid_padding() {
        // More than two '=' characters.
        assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
        assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());
        // '=' in the middle of the data.
        assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());
        // Data after the padding.
        assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());
        // Padded input whose length is not a multiple of 4.
        assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());
    }

    #[test]
    fn test_valid_padding() {
        assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok());
        assert!(Engine::STANDARD.decode("SGVsbA==").is_ok());
        assert!(Engine::STANDARD.decode("SGVs").is_ok());
    }

    #[test]
    fn test_unexpected_padding() {
        assert!(matches!(
            Engine::STANDARD_NO_PAD.decode("SGVsbG8="),
            Err(DecodeError::UnexpectedPadding)
        ));
        assert!(matches!(
            Engine::URL_SAFE_NO_PAD.decode("SGVsbG8="),
            Err(DecodeError::UnexpectedPadding)
        ));
        assert_eq!(Engine::STANDARD_NO_PAD.decode("SGVsbG8").unwrap(), b"Hello");
    }

    #[test]
    fn test_invalid_character() {
        assert!(matches!(
            Engine::STANDARD.decode("SGV$"),
            Err(DecodeError::InvalidCharacter('$'))
        ));
        // Each alphabet rejects the other's two distinguishing symbols.
        assert!(matches!(
            Engine::STANDARD.decode("-_-_"),
            Err(DecodeError::InvalidCharacter('-'))
        ));
        assert!(matches!(
            Engine::URL_SAFE_NO_PAD.decode("+/+/"),
            Err(DecodeError::InvalidCharacter('+'))
        ));
    }
}