Skip to main content

reifydb_type/util/
base64.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB
3
4use std::{error, fmt};
/// Standard base64 alphabet (RFC 4648, section 4): `A-Z`, `a-z`, `0-9`, `+`, `/`.
const BASE64_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// URL- and filename-safe alphabet (RFC 4648, section 5): identical to `BASE64_CHARS` except
/// that `-` and `_` replace `+` and `/`.
const BASE64_URL_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

/// A base64 codec parameterised by alphabet and padding policy.
///
/// The fields are private; use one of the presets [`Engine::STANDARD`],
/// [`Engine::STANDARD_NO_PAD`] or [`Engine::URL_SAFE_NO_PAD`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Engine {
	/// 64 ASCII symbols; a symbol's position in this slice is its 6-bit value.
	alphabet: &'static [u8],
	/// Whether `encode` emits `=` padding and `decode` accepts it.
	use_padding: bool,
}

impl Engine {
	/// Standard alphabet (`+`, `/`) with `=` padding.
	pub const STANDARD: Engine = Engine {
		alphabet: BASE64_CHARS,
		use_padding: true,
	};

	/// Standard alphabet (`+`, `/`) without padding.
	pub const STANDARD_NO_PAD: Engine = Engine {
		alphabet: BASE64_CHARS,
		use_padding: false,
	};

	/// URL-safe alphabet (`-`, `_`) without padding.
	pub const URL_SAFE_NO_PAD: Engine = Engine {
		alphabet: BASE64_URL_CHARS,
		use_padding: false,
	};

	/// Encodes `input` as base64 using this engine's alphabet.
	///
	/// Every 3-byte group becomes 4 symbols. A trailing group of 2 bytes yields 3 symbols and a
	/// trailing group of 1 byte yields 2 symbols, followed by `=`/`==` when the engine pads.
	/// Empty input yields an empty string.
	pub fn encode(&self, input: &[u8]) -> String {
		// Upper bound: 4 output symbols per (possibly partial) 3-byte group.
		let mut result = String::with_capacity(input.len().div_ceil(3) * 4);

		for chunk in input.chunks(3) {
			// Missing bytes of a short final group are treated as zero bits.
			let b1 = chunk[0];
			let b2 = chunk.get(1).copied().unwrap_or(0);
			let b3 = chunk.get(2).copied().unwrap_or(0);
			let n = (u32::from(b1) << 16) | (u32::from(b2) << 8) | u32::from(b3);

			// Symbol for the 6-bit group starting `shift` bits from the right of `n`.
			let sym = |shift: u32| char::from(self.alphabet[((n >> shift) & 63) as usize]);

			result.push(sym(18));
			result.push(sym(12));
			match chunk.len() {
				3 => {
					result.push(sym(6));
					result.push(sym(0));
				}
				2 => {
					result.push(sym(6));
					if self.use_padding {
						result.push('=');
					}
				}
				_ => {
					if self.use_padding {
						result.push_str("==");
					}
				}
			}
		}

		result
	}

	/// Decodes base64 `input` into bytes.
	///
	/// Padded engines accept input both with and without trailing `=`. When `=` is present it
	/// must be at most two characters, appear only at the very end, and the total length must be
	/// a multiple of 4. Unpadded engines reject any `=`.
	///
	/// Leftover bits that do not complete a byte are discarded without error. For example, a
	/// single dangling symbol after the last full byte is ignored.
	///
	/// # Errors
	///
	/// * [`DecodeError::UnexpectedPadding`]: `=` found by an engine that does not pad.
	/// * [`DecodeError::InvalidPadding`]: malformed padding (see above).
	/// * [`DecodeError::InvalidCharacter`]: a character outside the engine's alphabet,
	///   including any non-ASCII character.
	pub fn decode(&self, input: &str) -> Result<Vec<u8>, DecodeError> {
		if input.contains('=') {
			if !self.use_padding {
				return Err(DecodeError::UnexpectedPadding);
			}
			Self::check_padding(input)?;
		}

		let data = input.trim_end_matches('=');
		if data.is_empty() {
			return Ok(Vec::new());
		}

		// Every 4 symbols carry 3 bytes; for valid (ASCII) input this bound is exact.
		let mut result = Vec::with_capacity(data.len() * 3 / 4);
		let mut accumulator = 0u32;
		let mut bits_collected = 0u32;

		for ch in data.chars() {
			let value = self.char_to_value(ch)?;
			accumulator = (accumulator << 6) | u32::from(value);
			bits_collected += 6;

			if bits_collected >= 8 {
				bits_collected -= 8;
				result.push((accumulator >> bits_collected) as u8);
				// Keep only the bits not yet emitted so the accumulator stays small.
				accumulator &= (1 << bits_collected) - 1;
			}
		}

		Ok(result)
	}

	/// Validates the trailing `=` padding of `input` for a padded engine.
	///
	/// Returns [`DecodeError::InvalidPadding`] when there are more than two `=`, when `=`
	/// appears before the end, or when the total length is not a multiple of 4.
	fn check_padding(input: &str) -> Result<(), DecodeError> {
		// Split off the padding with `trim_end_matches` rather than `rfind(..) + 1`: the latter
		// is a byte offset that is not a char boundary when the last data character is
		// multi-byte, and slicing there panics.
		let data = input.trim_end_matches('=');
		let padding_count = input.len() - data.len();

		let malformed = padding_count > 2 || data.contains('=') || !input.len().is_multiple_of(4);
		if malformed {
			Err(DecodeError::InvalidPadding)
		} else {
			Ok(())
		}
	}

	/// Maps an alphabet symbol to its 6-bit value.
	fn char_to_value(&self, ch: char) -> Result<u8, DecodeError> {
		// `ch as u8` keeps only the low byte of the code point, so e.g. 'Ł' (U+0141) would
		// masquerade as 'A' (0x41). Both alphabets are pure ASCII, so anything else is invalid.
		if !ch.is_ascii() {
			return Err(DecodeError::InvalidCharacter(ch));
		}
		let byte = ch as u8;
		self.alphabet
			.iter()
			.position(|&b| b == byte)
			.map(|pos| pos as u8)
			.ok_or(DecodeError::InvalidCharacter(ch))
	}
}

/// Reasons [`Engine::decode`] can reject its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
	/// A character that is not part of the engine's alphabet. This includes every non-ASCII
	/// character.
	InvalidCharacter(char),
	/// `=` appeared in the input of an engine that does not pad. This is returned by every
	/// unpadded engine, `STANDARD_NO_PAD` included, although the `Display` text mentions
	/// URL-safe base64.
	UnexpectedPadding,
	/// Malformed padding: more than two `=`, `=` before the end, or a total length that is not
	/// a multiple of 4.
	InvalidPadding,
}

impl fmt::Display for DecodeError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self {
			DecodeError::InvalidCharacter(ch) => {
				write!(f, "Invalid base64 character: '{}'", ch)
			}
			DecodeError::UnexpectedPadding => write!(f, "Unexpected padding in URL-safe base64"),
			DecodeError::InvalidPadding => {
				write!(f, "Invalid base64 padding")
			}
		}
	}
}

impl error::Error for DecodeError {}
156
157pub mod engine {
158	pub mod general_purpose {
159		use crate::util::base64::Engine;
160
161		pub const STANDARD: Engine = Engine::STANDARD;
162		pub const STANDARD_NO_PAD: Engine = Engine::STANDARD_NO_PAD;
163		pub const URL_SAFE_NO_PAD: Engine = Engine::URL_SAFE_NO_PAD;
164	}
165}
166
#[cfg(test)]
pub mod tests {
	use super::*;

	#[test]
	fn test_encode_standard() {
		assert_eq!(Engine::STANDARD.encode(b"Hello"), "SGVsbG8=");
		assert_eq!(Engine::STANDARD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
		assert_eq!(Engine::STANDARD.encode(b""), "");
	}

	#[test]
	fn test_encode_no_pad() {
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello"), "SGVsbG8");
		assert_eq!(Engine::STANDARD_NO_PAD.encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ");
	}

	#[test]
	fn test_decode_standard() {
		assert_eq!(Engine::STANDARD.decode("SGVsbG8=").unwrap(), b"Hello");
		// The padded engine also accepts unpadded input.
		assert_eq!(Engine::STANDARD.decode("SGVsbG8").unwrap(), b"Hello");
		assert_eq!(Engine::STANDARD.decode("").unwrap(), b"");
	}

	#[test]
	fn test_roundtrip() {
		let data = b"Hello, World! \x00\x01\x02\xFF";
		let encoded = Engine::STANDARD.encode(data);
		let decoded = Engine::STANDARD.decode(&encoded).unwrap();
		assert_eq!(decoded, data);
	}

	#[test]
	fn test_invalid_padding() {
		// Too many padding characters
		assert!(Engine::STANDARD.decode("SGVsbG8===").is_err());
		assert!(Engine::STANDARD.decode("SGVsbG8====").is_err());

		// Padding in the middle
		assert!(Engine::STANDARD.decode("SGVs=bG8=").is_err());

		// Data after the padding (the length, 9, is also not a multiple of 4)
		assert!(Engine::STANDARD.decode("SGVsbG8=X").is_err());

		// "Hello" encodes to "SGVsbG8=" (8 chars, one '='). A second '=' makes the length 9,
		// which is not a multiple of 4.
		assert!(Engine::STANDARD.decode("SGVsbG8==").is_err());

		// Valid padding should work
		assert!(Engine::STANDARD.decode("SGVsbG8=").is_ok()); // "Hello" - needs 1 padding
		assert!(Engine::STANDARD.decode("SGVsbA==").is_ok()); // "Hell" - needs 2 padding
		assert!(Engine::STANDARD.decode("SGVs").is_ok()); // "Hel" - no padding needed
	}

	#[test]
	fn test_no_pad_engines_reject_padding() {
		assert!(Engine::STANDARD_NO_PAD.decode("SGVsbG8=").is_err());
		assert!(Engine::URL_SAFE_NO_PAD.decode("SGVsbG8=").is_err());
		assert_eq!(Engine::STANDARD_NO_PAD.decode("SGVsbG8").unwrap(), b"Hello");
	}

	#[test]
	fn test_url_safe_roundtrip() {
		// 0xfb 0xff 0xbf maps to the 6-bit values 62, 63, 62, 63, i.e. the alphabet-specific symbols.
		let data = [0xfbu8, 0xff, 0xbf];
		let encoded = Engine::URL_SAFE_NO_PAD.encode(&data);
		assert_eq!(encoded, "-_-_");
		assert_eq!(Engine::URL_SAFE_NO_PAD.decode(&encoded).unwrap(), data);
		assert_eq!(Engine::STANDARD.encode(&data), "+/+/");
	}

	#[test]
	fn test_invalid_character() {
		assert!(Engine::STANDARD.decode("SGV$").is_err());
		// '-' and '_' belong to the URL-safe alphabet only
		assert!(Engine::STANDARD.decode("-_-_").is_err());
	}
}