reifydb_type/util/
base64.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the MIT, see license.md file
3
4//! Simple base64 encoding/decoding implementation
5
/// Standard base64 alphabet (RFC 4648, Table 1); a symbol's index is its 6-bit value.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
	abcdefghijklmnopqrstuvwxyz\
	0123456789+/";
/// URL- and filename-safe alphabet (RFC 4648, Table 2): `-` and `_` replace `+` and `/`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
	abcdefghijklmnopqrstuvwxyz\
	0123456789-_";
8
/// Base64 encoding engine: an alphabet paired with a padding mode.
///
/// Use one of the predefined configurations such as [`Engine::STANDARD`].
/// The type is small and `Copy`, so engines can be passed around by value.
#[derive(Debug, Clone, Copy)]
pub struct Engine {
	/// The 64 symbols of the encoding, indexed by 6-bit value.
	alphabet: &'static [u8],
	/// Whether `encode` emits `=` padding and `decode` accepts it.
	use_padding: bool,
}
14
15impl Engine {
16	/// Standard base64 with padding
17	pub const STANDARD: Engine = Engine {
18		alphabet: BASE64_CHARS,
19		use_padding: true,
20	};
21
22	/// Standard base64 without padding
23	pub const STANDARD_NO_PAD: Engine = Engine {
24		alphabet: BASE64_CHARS,
25		use_padding: false,
26	};
27
28	/// URL-safe base64 without padding
29	pub const URL_SAFE_NO_PAD: Engine = Engine {
30		alphabet: BASE64_URL_CHARS,
31		use_padding: false,
32	};
33
34	/// Encode bytes to base64 string
35	pub fn encode(&self, input: &[u8]) -> String {
36		if input.is_empty() {
37			return String::new();
38		}
39
40		let mut result = String::new();
41		let mut i = 0;
42
43		while i < input.len() {
44			let b1 = input[i];
45			let b2 = if i + 1 < input.len() {
46				input[i + 1]
47			} else {
48				0
49			};
50			let b3 = if i + 2 < input.len() {
51				input[i + 2]
52			} else {
53				0
54			};
55
56			let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
57
58			result.push(self.alphabet[(n >> 18) & 63] as char);
59			result.push(self.alphabet[(n >> 12) & 63] as char);
60
61			if i + 1 < input.len() {
62				result.push(self.alphabet[(n >> 6) & 63] as char);
63				if i + 2 < input.len() {
64					result.push(self.alphabet[n & 63] as char);
65				} else if self.use_padding {
66					result.push('=');
67				}
68			} else if self.use_padding {
69				result.push('=');
70				result.push('=');
71			}
72
73			i += 3;
74		}
75
76		result
77	}
78
79	/// Decode base64 string to bytes
80	pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
81		// URL-safe base64 should not have padding
82		if !self.use_padding && input.contains('=') {
83			return Err(DecodeError::UnexpectedPadding);
84		}
85
86		// Validate padding if present
87		if self.use_padding && input.contains('=') {
88			// Count trailing padding characters
89			let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
90			let padding_count = input.len() - padding_start;
91
92			// Valid base64 can only have 0, 1, or 2 padding
93			// characters
94			if padding_count > 2 {
95				return Err(DecodeError::InvalidPadding);
96			}
97
98			// Check that padding only appears at the end
99			if padding_start > 0 && input[..padding_start].contains('=') {
100				return Err(DecodeError::InvalidPadding);
101			}
102
103			// Total length must be divisible by 4
104			if input.len() % 4 != 0 {
105				return Err(DecodeError::InvalidPadding);
106			}
107
108			// Validate padding count based on the last quantum
109			// The last quantum (4 chars) can be:
110			// - XXXX (no padding)
111			// - XXX= (1 padding)
112			// - XX== (2 padding)
113			let non_padding_in_last_quantum = 4 - padding_count;
114			if non_padding_in_last_quantum < 2 {
115				return Err(DecodeError::InvalidPadding);
116			}
117		}
118
119		let input = input.trim_end_matches('=');
120		if input.is_empty() {
121			return Ok(Vec::new());
122		}
123
124		let mut result = Vec::new();
125		let mut accumulator = 0u32;
126		let mut bits_collected = 0;
127
128		for ch in input.chars() {
129			let value = self.char_to_value(ch)?;
130			accumulator = (accumulator << 6) | (value as u32);
131			bits_collected += 6;
132
133			if bits_collected >= 8 {
134				bits_collected -= 8;
135				result.push((accumulator >> bits_collected) as u8);
136				accumulator &= (1 << bits_collected) - 1;
137			}
138		}
139
140		Ok(result)
141	}
142
143	fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
144		let byte = ch as u8;
145		self.alphabet
146			.iter()
147			.position(|&b| b == byte)
148			.map(|pos| pos as u8)
149			.ok_or(DecodeError::InvalidCharacter(ch))
150	}
151}
152
/// Error returned by `Engine::decode` when the input is not valid base64.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
	/// The input contained a character outside the engine's alphabet.
	InvalidCharacter(char),
	/// The input contained `=` but the engine does not use padding.
	UnexpectedPadding,
	/// The input's padding or length is malformed, e.g. `=` before the end,
	/// more than two `=`, or padded input whose length is not a multiple of 4.
	InvalidPadding,
}
159
160impl std::fmt::Display for DecodeError {
161	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162		match self {
163			DecodeError::InvalidCharacter(ch) => {
164				write!(f, "Invalid base64 character: '{}'", ch)
165			}
166			DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
167			DecodeError::InvalidPadding => {
168				write!(f, "Invalid base64 padding")
169			}
170		}
171	}
172}
173
174impl std::error::Error for DecodeError {}
175
176// Convenience module to match the original API
177pub mod engine {
178	pub mod general_purpose {
179		pub const STANDARD: super::super::Engine = super::super::Engine::STANDARD;
180		pub const STANDARD_NO_PAD: super::super::Engine = super::super::Engine::STANDARD_NO_PAD;
181		pub const URL_SAFE_NO_PAD: super::super::Engine = super::super::Engine::URL_SAFE_NO_PAD;
182	}
183}
184
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_encode_standard() {
		assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
		assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
		assert_eq!(Engine::STANDARD.encode(b""), "");
	}

	#[test]
	fn test_encode_no_pad() {
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
	}

	#[test]
	fn test_encode_url_safe() {
		// 0xFB 0xFF yields sextets 62, 63, 60, exercising the two symbols
		// that differ between the standard and URL-safe alphabets.
		assert_eq!(Engine::STANDARD.encode(&[0xfb, 0xff]), "+/8=");
		assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
		assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_8").unwrap(), &[0xfb_u8, 0xff][..]);
	}

	#[test]
	fn test_decode_standard() {
		assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
	}

	#[test]
	fn test_roundtrip() {
		let data = b"Hello, World! \x00\x01\x02\xFF";
		let encoded = Engine::STANDARD.encode(data);
		let decoded = Engine::STANDARD.decode(&encoded).unwrap();
		assert_eq!(decoded, data);
	}

	#[test]
	fn test_roundtrip_all_engines_and_lengths() {
		// Lengths 0..=12 hit every remainder mod 3 several times; descending
		// byte values exercise the high end of the alphabet ('+', '/', '-', '_').
		let data: Vec<u8> = (0..=255u8).rev().collect();
		for engine in [Engine::STANDARD, Engine::STANDARD_NO_PAD, Engine::URL_SAFE_NO_PAD].iter() {
			for len in 0..=12 {
				let input = &data[..len];
				let encoded = engine.encode(input);
				assert_eq!(engine.decode(&encoded).unwrap(), input);
			}
		}
	}

	#[test]
	fn test_unexpected_padding() {
		// Engines without padding reject '=' anywhere in the input.
		assert!(matches!(Engine::STANDARD_NO_PAD.decode("SGVsbG8="), Err(DecodeError::UnexpectedPadding)));
		assert!(matches!(Engine::URL_SAFE_NO_PAD.decode("SG=V"), Err(DecodeError::UnexpectedPadding)));
	}

	#[test]
	fn test_invalid_character() {
		assert!(matches!(Engine::STANDARD.decode("SGV$"), Err(DecodeError::InvalidCharacter('$'))));
		// '-' and '_' exist only in the URL-safe alphabet.
		assert!(matches!(Engine::STANDARD.decode("-_8"), Err(DecodeError::InvalidCharacter('-'))));
	}

	#[test]
	fn test_invalid_padding() {
		// Too many padding characters
		assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
		assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

		// Padding in the middle
		assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

		// Data after padding (also not a multiple of 4 long)
		assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

		// "Hello" needs exactly one '=' ("SGVsbG8="); a second one makes the
		// input 9 characters long, which is not a multiple of 4.
		assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

		// Valid padding should work
		assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok()); // "Hello" - needs 1 padding
		assert!(Engine::STANDARD.decode("SGVsbA==").is_ok()); // "Hell" - needs 2 padding
		assert!(Engine::STANDARD.decode("SGVs").is_ok()); // "Hel" - no padding needed
	}
}