Skip to main content

reifydb_type/util/
base64.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{error, fmt};
/// Standard base64 alphabet; symbols 62 and 63 are `+` and `/`.
///
/// Typed as a 64-element array so the compiler rejects an alphabet of the
/// wrong length (the encoder indexes it with 6-bit values).
const BASE64_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL-safe base64 alphabet; symbols 62 and 63 are `-` and `_`.
const BASE64_URL_CHARS: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
7
/// Base64 encoding engine
///
/// An engine pairs a symbol alphabet with a padding policy. It holds only a
/// static slice and a flag, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct Engine {
	/// Symbol table indexed by 6-bit value when encoding.
	alphabet: &'static [u8],
	/// Whether `=` padding is emitted when encoding and accepted when decoding.
	use_padding: bool,
}
13
14impl Engine {
15	/// Standard base64 with padding
16	pub const STANDARD: Engine = Engine {
17		alphabet: BASE64_CHARS,
18		use_padding: true,
19	};
20
21	/// Standard base64 without padding
22	pub const STANDARD_NO_PAD: Engine = Engine {
23		alphabet: BASE64_CHARS,
24		use_padding: false,
25	};
26
27	/// URL-safe base64 without padding
28	pub const URL_SAFE_NO_PAD: Engine = Engine {
29		alphabet: BASE64_URL_CHARS,
30		use_padding: false,
31	};
32
33	/// Encode bytes to base64 string
34	pub fn encode(&self, input: &[u8]) -> String {
35		if input.is_empty() {
36			return String::new();
37		}
38
39		let mut result = String::new();
40		let mut i = 0;
41
42		while i < input.len() {
43			let b1 = input[i];
44			let b2 = if i + 1 < input.len() {
45				input[i + 1]
46			} else {
47				0
48			};
49			let b3 = if i + 2 < input.len() {
50				input[i + 2]
51			} else {
52				0
53			};
54
55			let n = ((b1 as usize) << 16) | ((b2 as usize) << 8) | (b3 as usize);
56
57			result.push(self.alphabet[(n >> 18) & 63] as char);
58			result.push(self.alphabet[(n >> 12) & 63] as char);
59
60			if i + 1 < input.len() {
61				result.push(self.alphabet[(n >> 6) & 63] as char);
62				if i + 2 < input.len() {
63					result.push(self.alphabet[n & 63] as char);
64				} else if self.use_padding {
65					result.push('=');
66				}
67			} else if self.use_padding {
68				result.push('=');
69				result.push('=');
70			}
71
72			i += 3;
73		}
74
75		result
76	}
77
78	/// Decode base64 string to bytes
79	pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
80		// URL-safe base64 should not have padding
81		if !self.use_padding && input.contains('=') {
82			return Err(DecodeError::UnexpectedPadding);
83		}
84
85		// Validate padding if present
86		if self.use_padding && input.contains('=') {
87			// Count trailing padding characters
88			let padding_start = input.rfind(|c| c != '=').map(|i| i + 1).unwrap_or(0);
89			let padding_count = input.len() - padding_start;
90
91			// Valid base64 can only have 0, 1, or 2 padding
92			// characters
93			if padding_count > 2 {
94				return Err(DecodeError::InvalidPadding);
95			}
96
97			// Check that padding only appears at the end
98			if padding_start > 0 && input[..padding_start].contains('=') {
99				return Err(DecodeError::InvalidPadding);
100			}
101
102			// Total length must be divisible by 4
103			if !input.len().is_multiple_of(4) {
104				return Err(DecodeError::InvalidPadding);
105			}
106
107			// Validate padding count based on the last quantum
108			// The last quantum (4 chars) can be:
109			// - XXXX (no padding)
110			// - XXX= (1 padding)
111			// - XX== (2 padding)
112			let non_padding_in_last_quantum = 4 - padding_count;
113			if non_padding_in_last_quantum < 2 {
114				return Err(DecodeError::InvalidPadding);
115			}
116		}
117
118		let input = input.trim_end_matches('=');
119		if input.is_empty() {
120			return Ok(Vec::new());
121		}
122
123		let mut result = Vec::new();
124		let mut accumulator = 0u32;
125		let mut bits_collected = 0;
126
127		for ch in input.chars() {
128			let value = self.char_to_value(ch)?;
129			accumulator = (accumulator << 6) | (value as u32);
130			bits_collected += 6;
131
132			if bits_collected >= 8 {
133				bits_collected -= 8;
134				result.push((accumulator >> bits_collected) as u8);
135				accumulator &= (1 << bits_collected) - 1;
136			}
137		}
138
139		Ok(result)
140	}
141
142	fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
143		let byte = ch as u8;
144		self.alphabet
145			.iter()
146			.position(|&b| b == byte)
147			.map(|pos| pos as u8)
148			.ok_or(DecodeError::InvalidCharacter(ch))
149	}
150}
151
/// Reasons why base64 decoding can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
	/// The input contains a character that is not in the engine's alphabet.
	InvalidCharacter(char),
	/// The input contains `=` although the engine does not use padding.
	UnexpectedPadding,
	/// The `=` padding is misplaced or has the wrong length.
	InvalidPadding,
}
158
159impl fmt::Display for DecodeError {
160	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161		match self {
162			DecodeError::InvalidCharacter(ch) => {
163				write!(f, "Invalid base64 character: '{}'", ch)
164			}
165			DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
166			DecodeError::InvalidPadding => {
167				write!(f, "Invalid base64 padding")
168			}
169		}
170	}
171}
172
173impl error::Error for DecodeError {}
174
175// Convenience module to match the original API
176pub mod engine {
177	pub mod general_purpose {
178		use crate::util::base64::Engine;
179
180		pub const STANDARD: Engine = Engine::STANDARD;
181		pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
182		pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
183	}
184}
185
// Unit tests for the base64 engines: known vectors, round-trips and the
// rejection paths of `decode`.
#[cfg(test)]
pub mod tests {
	use super::*;

	#[test]
	fn test_encode_standard() {
		assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
		assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
		assert_eq!(Engine::STANDARD.encode(b""), "");
	}

	#[test]
	fn test_encode_no_pad() {
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
	}

	#[test]
	fn test_encode_url_safe() {
		// 0xfb 0xff exercises symbols 62 and 63, where the alphabets differ.
		assert_eq!(Engine::STANDARD.encode(&[0xfb, 0xff]), "+/8=");
		assert_eq!(Engine::URL_SAFE_NO_PAD.encode(&[0xfb, 0xff]), "-_8");
	}

	#[test]
	fn test_decode_standard() {
		assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
	}

	#[test]
	fn test_decode_url_safe() {
		assert_eq!(Engine::URL_SAFE_NO_PAD.decode("-_8").unwrap(), vec![0xfb, 0xff]);

		// Standard-only symbols are not part of the URL-safe alphabet
		assert!(matches!(Engine::URL_SAFE_NO_PAD.decode("+/8"), Err(DecodeError::InvalidCharacter('+'))));
	}

	#[test]
	fn test_roundtrip() {
		let data = b"Hello, World! \x00\x01\x02\xFF";
		let encoded = Engine::STANDARD.encode(data);
		let decoded = Engine::STANDARD.decode(&encoded).unwrap();
		assert_eq!(decoded, data);
	}

	#[test]
	fn test_unexpected_padding() {
		// Engines without padding reject any '='
		assert!(matches!(Engine::STANDARD_NO_PAD.decode("SGVsbG8="), Err(DecodeError::UnexpectedPadding)));
		assert!(matches!(Engine::URL_SAFE_NO_PAD.decode("SGVsbG8="), Err(DecodeError::UnexpectedPadding)));
	}

	#[test]
	fn test_invalid_character() {
		assert!(matches!(Engine::STANDARD.decode("SGV$"), Err(DecodeError::InvalidCharacter('$'))));
	}

	#[test]
	fn test_invalid_padding() {
		// Too many padding characters
		assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
		assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

		// Padding in the middle
		assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

		// Invalid length with padding (not divisible by 4)
		assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

		// Invalid: "SGVsbG8" needs exactly 1 padding char, but has 2
		assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

		// Valid padding should work
		assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok()); // "Hello" - needs 1 padding
		assert!(Engine::STANDARD.decode("SGVsbA==").is_ok()); // "Hell" - needs 2 padding
		assert!(Engine::STANDARD.decode("SGVs").is_ok()); // "Hel" - no padding needed
	}
}
239}