Skip to main content

reifydb_type/util/
base64.rs

1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 ReifyDB
3
4//! Simple base64 encoding/decoding implementation
5
6use std::{error, fmt};
/// Standard alphabet (RFC 4648 §4); a symbol's index is its 6-bit value.
/// The array type pins the length to exactly 64 symbols at compile time.
const BASE64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL- and filename-safe alphabet (RFC 4648 §5): identical to the standard one
/// except that values 62 and 63 are `-` and `_` instead of `+` and `/`.
const BASE64_URL_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
9
/// Base64 encoding engine: an alphabet plus a padding policy.
///
/// The fields are private, so use one of the predefined engines
/// ([`Engine::STANDARD`], [`Engine::STANDARD_NO_PAD`], [`Engine::URL_SAFE_NO_PAD`]).
#[derive(Debug, Clone, Copy)]
pub struct Engine {
	/// Symbol table indexed by 6-bit value.
	alphabet: &'static [u8],
	/// Whether `=` padding is emitted when encoding and accepted when decoding.
	use_padding: bool,
}
15
16impl Engine {
17	/// Standard base64 with padding
18	pub const STANDARD: Engine = Engine {
19		alphabet: BASE64_CHARS,
20		use_padding: true,
21	};
22
23	/// Standard base64 without padding
24	pub const STANDARD_NO_PAD: Engine = Engine {
25		alphabet: BASE64_CHARS,
26		use_padding: false,
27	};
28
29	/// URL-safe base64 without padding
30	pub const URL_SAFE_NO_PAD: Engine = Engine {
31		alphabet: BASE64_URL_CHARS,
32		use_padding: false,
33	};
34
35	/// Encode bytes to base64 string
36	pub fn encode(&self, input: &[u8]) -> String {
37		if input.is_empty() {
38			return String::new();
39		}
40
41		let mut result = String::new();
42		let mut i = 0;
43
44		while i < input.len() {
45			let b1 = input[i];
46			let b2 = if i + 1 < input.len() {
47				input[i + 1]
48			} else {
49				0
50			};
51			let b3 = if i + 2 < input.len() {
52				input[i + 2]
53			} else {
54				0
55			};
56
57			let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
58
59			result.push(self.alphabet[(n >> 18) & 63] as char);
60			result.push(self.alphabet[(n >> 12) & 63] as char);
61
62			if i + 1 < input.len() {
63				result.push(self.alphabet[(n >> 6) & 63] as char);
64				if i + 2 < input.len() {
65					result.push(self.alphabet[n & 63] as char);
66				} else if self.use_padding {
67					result.push('=');
68				}
69			} else if self.use_padding {
70				result.push('=');
71				result.push('=');
72			}
73
74			i += 3;
75		}
76
77		result
78	}
79
80	/// Decode base64 string to bytes
81	pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
82		// URL-safe base64 should not have padding
83		if !self.use_padding && input.contains('=') {
84			return Err(DecodeError::UnexpectedPadding);
85		}
86
87		// Validate padding if present
88		if self.use_padding && input.contains('=') {
89			// Count trailing padding characters
90			let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
91			let padding_count = input.len() - padding_start;
92
93			// Valid base64 can only have 0, 1, or 2 padding
94			// characters
95			if padding_count > 2 {
96				return Err(DecodeError::InvalidPadding);
97			}
98
99			// Check that padding only appears at the end
100			if padding_start > 0 && input[..padding_start].contains('=') {
101				return Err(DecodeError::InvalidPadding);
102			}
103
104			// Total length must be divisible by 4
105			if input.len() % 4 != 0 {
106				return Err(DecodeError::InvalidPadding);
107			}
108
109			// Validate padding count based on the last quantum
110			// The last quantum (4 chars) can be:
111			// - XXXX (no padding)
112			// - XXX= (1 padding)
113			// - XX== (2 padding)
114			let non_padding_in_last_quantum = 4 - padding_count;
115			if non_padding_in_last_quantum < 2 {
116				return Err(DecodeError::InvalidPadding);
117			}
118		}
119
120		let input = input.trim_end_matches('=');
121		if input.is_empty() {
122			return Ok(Vec::new());
123		}
124
125		let mut result = Vec::new();
126		let mut accumulator = 0u32;
127		let mut bits_collected = 0;
128
129		for ch in input.chars() {
130			let value = self.char_to_value(ch)?;
131			accumulator = (accumulator << 6) | (value as u32);
132			bits_collected += 6;
133
134			if bits_collected >= 8 {
135				bits_collected -= 8;
136				result.push((accumulator >> bits_collected) as u8);
137				accumulator &= (1 << bits_collected) - 1;
138			}
139		}
140
141		Ok(result)
142	}
143
144	fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
145		let byte = ch as u8;
146		self.alphabet
147			.iter()
148			.position(|&b| b == byte)
149			.map(|pos| pos as u8)
150			.ok_or(DecodeError::InvalidCharacter(ch))
151	}
152}
153
/// Reasons a base64 string can fail to decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
	/// A character that is not part of the engine's alphabet.
	InvalidCharacter(char),
	/// An `=` was found by an engine that does not use padding.
	UnexpectedPadding,
	/// Padding is misplaced or too long, or the input length is inconsistent with it.
	InvalidPadding,
}
160
161impl fmt::Display for DecodeError {
162	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163		match self {
164			DecodeError::InvalidCharacter(ch) => {
165				write!(f, "Invalid base64 character: '{}'", ch)
166			}
167			DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
168			DecodeError::InvalidPadding => {
169				write!(f, "Invalid base64 padding")
170			}
171		}
172	}
173}
174
// `DecodeError` derives `Debug` and implements `Display`, which is all
// `std::error::Error` requires. No variant wraps another error, so the default
// `source()` is kept.
impl error::Error for DecodeError {}
176
177// Convenience module to match the original API
178pub mod engine {
179	pub mod general_purpose {
180		use crate::util::base64::Engine;
181
182		pub const STANDARD: Engine = Engine::STANDARD;
183		pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
184		pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
185	}
186}
187
#[cfg(test)]
pub mod tests {
	use super::*;

	#[test]
	fn test_encode_standard() {
		assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
		assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
		assert_eq!(Engine::STANDARD.encode(b""), "");
	}

	#[test]
	fn test_encode_no_pad() {
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
	}

	#[test]
	fn test_decode_standard() {
		assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
	}

	#[test]
	fn test_roundtrip() {
		let data = b"Hello, World! \x00\x01\x02\xFF";
		let encoded = Engine::STANDARD.encode(data);
		let decoded = Engine::STANDARD.decode(&encoded).unwrap();
		assert_eq!(decoded, data);
	}

	#[test]
	fn test_invalid_padding() {
		// Too many padding characters
		assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
		assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

		// Padding in the middle
		assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

		// Invalid length with padding (not divisible by 4)
		assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

		// "SGVsbG8" needs exactly one `=`; with two, the total length (9) is not
		// a multiple of 4
		assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

		// Valid padding should work
		assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok()); // "Hello" - needs 1 padding
		assert!(Engine::STANDARD.decode("SGVsbA==").is_ok()); // "Hell" - needs 2 padding
		assert!(Engine::STANDARD.decode("SGVs").is_ok()); // "Hel" - no padding needed
	}

	#[test]
	fn test_url_safe_no_pad() {
		// Values 62 and 63 use `-`/`_` in the URL-safe alphabet and `+`/`/` in
		// the standard one.
		assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
		assert_eq!(Engine::STANDARD.encode(&[0xfb, 0xff]), "+/8=");
		assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_8").unwrap(), vec![0xfbu8, 0xff]);
	}

	#[test]
	fn test_unexpected_padding() {
		assert!(matches!(Engine::STANDARD_NO_PAD.decode("SGVsbG8="), Err(DecodeError::UnexpectedPadding)));
		assert!(matches!(Engine::URL_SAFE_NO_PAD.decode("SGVsbG8="), Err(DecodeError::UnexpectedPadding)));
	}

	#[test]
	fn test_invalid_character() {
		assert!(matches!(Engine::STANDARD.decode("SGV!"), Err(DecodeError::InvalidCharacter('!'))));
		// `+` belongs to the standard alphabet only.
		assert!(matches!(Engine::URL_SAFE_NO_PAD.decode("+/8"), Err(DecodeError::InvalidCharacter('+'))));
	}
}