Skip to main content

nexcore_codec/
base64.rs

1//! Base64 encoding and decoding (RFC 4648 §4 + §5).
2//!
3//! Zero-dependency replacement for the `base64` crate.
4//!
5//! # Supply Chain Sovereignty
6//!
7//! This module has **zero external dependencies**. It replaces the `base64` crate
8//! for the `nexcore` ecosystem.
9//!
10//! # Alphabets
11//!
12//! - **Standard** (§4): `A-Z a-z 0-9 + /` with `=` padding
13//! - **URL-safe** (§5): `A-Z a-z 0-9 - _` with optional padding
14//!
15//! # Examples
16//!
17//! ```
18//! use nexcore_codec::base64;
19//!
20//! let encoded = base64::encode(b"Hello, World!");
21//! assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
22//!
23//! let decoded = base64::decode("SGVsbG8sIFdvcmxkIQ==").unwrap();
24//! assert_eq!(decoded, b"Hello, World!");
25//! ```
26
/// Encoding table for the standard Base64 alphabet (RFC 4648 §4).
///
/// The byte at position `i` is the ASCII symbol emitted for the 6-bit value
/// `i`; the values 62 and 63 map to `+` and `/`.
const STANDARD: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                              abcdefghijklmnopqrstuvwxyz\
                              0123456789\
                              +/";
29
/// Encoding table for the URL- and filename-safe Base64 alphabet (RFC 4648 §5).
///
/// The byte at position `i` is the ASCII symbol emitted for the 6-bit value
/// `i`; the values 62 and 63 map to `-` and `_`.
const URL_SAFE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                              abcdefghijklmnopqrstuvwxyz\
                              0123456789\
                              -_";
32
33/// Encode bytes using standard Base64 with `=` padding.
34#[must_use]
35pub fn encode(input: impl AsRef<[u8]>) -> String {
36    encode_with_alphabet(input.as_ref(), STANDARD, true)
37}
38
39/// Decode a standard Base64 string (with or without padding).
40pub fn decode(input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
41    decode_with_alphabet(input.as_ref(), false)
42}
43
44/// Encode bytes using URL-safe Base64 without padding.
45#[must_use]
46pub fn encode_url_safe_no_pad(input: impl AsRef<[u8]>) -> String {
47    encode_with_alphabet(input.as_ref(), URL_SAFE, false)
48}
49
50/// Decode a URL-safe Base64 string (without padding).
51pub fn decode_url_safe_no_pad(input: impl AsRef<[u8]>) -> Result<Vec<u8>, DecodeError> {
52    decode_with_alphabet(input.as_ref(), true)
53}
54
55/// Encode bytes using URL-safe Base64 with `=` padding.
56#[must_use]
57pub fn encode_url_safe(input: impl AsRef<[u8]>) -> String {
58    encode_with_alphabet(input.as_ref(), URL_SAFE, true)
59}
60
/// Reasons a Base64 input can be rejected by the decoders in this module.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// A byte that does not belong to the selected alphabet was found.
    InvalidChar {
        /// Position the decoder reported for the offending byte.
        index: usize,
        /// The offending byte itself.
        byte: u8,
    },
    /// The number of symbols cannot be decoded into whole bytes, e.g. a
    /// single symbol is left over after the last complete group of four.
    InvalidLength,
    /// The `=` padding is malformed.
    InvalidPadding,
}

impl core::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only `InvalidChar` carries data to interpolate; the other variants
        // have fixed messages that are written out verbatim.
        let message = match *self {
            Self::InvalidChar { index, byte } => {
                return write!(f, "invalid base64 char 0x{byte:02x} at index {index}");
            }
            Self::InvalidLength => "invalid base64 length",
            Self::InvalidPadding => "invalid base64 padding",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DecodeError {}
86
/// Maps a 6-bit value to its ASCII symbol in `alphabet`.
///
/// The encoder in this file masks every value with `0x3F` before calling, so
/// `idx` is at most 63 there. Passing 64 or more panics on the out-of-bounds
/// index.
#[inline]
fn alphabet_char(alphabet: &[u8; 64], idx: u32) -> char {
    #[allow(
        clippy::as_conversions,
        reason = "idx is a masked 6-bit value (at most 63), which fits in usize on every target"
    )]
    let slot = idx as usize;
    #[allow(
        clippy::indexing_slicing,
        reason = "slot is at most 63 for every caller in this file; alphabet has 64 elements"
    )]
    let symbol = alphabet[slot];
    // `char::from(u8)` maps the byte to the code point of the same value,
    // which for the ASCII alphabets is the symbol itself.
    char::from(symbol)
}
107
108fn encode_with_alphabet(input: &[u8], alphabet: &[u8; 64], pad: bool) -> String {
109    // Capacity: ceil(n / 3) * 4. Use saturating arithmetic — inputs large
110    // enough to overflow usize would OOM long before reaching this point.
111    let capacity = input
112        .len()
113        .saturating_add(2)
114        .checked_div(3)
115        .unwrap_or(0)
116        .saturating_mul(4);
117    let mut out = String::with_capacity(capacity);
118    let chunks = input.chunks_exact(3);
119    let remainder = chunks.remainder();
120
121    for chunk in chunks {
122        // chunks_exact(3) guarantees chunk has exactly 3 elements.
123        #[allow(
124            clippy::indexing_slicing,
125            reason = "chunks_exact(3) guarantees chunk.len() == 3; indices 0, 1, 2 are always valid"
126        )]
127        let n = (u32::from(chunk[0]) << 16) | (u32::from(chunk[1]) << 8) | u32::from(chunk[2]);
128        out.push(alphabet_char(alphabet, (n >> 18) & 0x3F));
129        out.push(alphabet_char(alphabet, (n >> 12) & 0x3F));
130        out.push(alphabet_char(alphabet, (n >> 6) & 0x3F));
131        out.push(alphabet_char(alphabet, n & 0x3F));
132    }
133
134    match remainder.len() {
135        1 => {
136            // chunks_exact remainder with len == 1: index 0 is valid.
137            #[allow(
138                clippy::indexing_slicing,
139                reason = "remainder.len() == 1 is proven by the match arm; index 0 is always valid"
140            )]
141            let n = u32::from(remainder[0]) << 16;
142            out.push(alphabet_char(alphabet, (n >> 18) & 0x3F));
143            out.push(alphabet_char(alphabet, (n >> 12) & 0x3F));
144            if pad {
145                out.push('=');
146                out.push('=');
147            }
148        }
149        2 => {
150            // chunks_exact remainder with len == 2: indices 0 and 1 are valid.
151            #[allow(
152                clippy::indexing_slicing,
153                reason = "remainder.len() == 2 is proven by the match arm; indices 0 and 1 are always valid"
154            )]
155            let n = (u32::from(remainder[0]) << 16) | (u32::from(remainder[1]) << 8);
156            out.push(alphabet_char(alphabet, (n >> 18) & 0x3F));
157            out.push(alphabet_char(alphabet, (n >> 12) & 0x3F));
158            out.push(alphabet_char(alphabet, (n >> 6) & 0x3F));
159            if pad {
160                out.push('=');
161            }
162        }
163        _ => {}
164    }
165
166    out
167}
168
169fn decode_with_alphabet(input: &[u8], url_safe: bool) -> Result<Vec<u8>, DecodeError> {
170    // Strip whitespace and padding
171    let input: Vec<u8> = input
172        .iter()
173        .copied()
174        .filter(|&b| b != b'\n' && b != b'\r' && b != b' ' && b != b'\t')
175        .collect();
176
177    // Strip trailing padding. `pad_count` is bounded by `input.len()` because
178    // `take_while` cannot yield more elements than the iterator contains.
179    let input_len = input.len();
180    let pad_count = input.iter().rev().take_while(|&&b| b == b'=').count();
181    // pad_count is produced by `take_while` on `input.iter()`, which cannot
182    // yield more elements than the iterator contains, so `pad_count <= input_len`.
183    // The subtraction therefore cannot underflow, and the slice is always in bounds.
184    let data_len = input_len.saturating_sub(pad_count);
185    // `data_len <= input_len` by construction; `.get(..)` returns `None` only
186    // if `data_len > input.len()`, which is impossible, so `unwrap_or` with an
187    // empty slice is the safe fallback that can never actually be reached.
188    let data: &[u8] = input.get(..data_len).unwrap_or(&[]);
189
190    if data.is_empty() {
191        return Ok(Vec::new());
192    }
193
194    // Validate length: data length mod 4 must not be 1 (would be incomplete group).
195    let mod4 = data.len() % 4;
196    if mod4 == 1 {
197        return Err(DecodeError::InvalidLength);
198    }
199
200    // Capacity: floor(n * 3 / 4). Use saturating to avoid overflow on huge
201    // inputs — such inputs would OOM before reaching this point.
202    let capacity = data.len().saturating_mul(3).checked_div(4).unwrap_or(0);
203    let mut out = Vec::with_capacity(capacity);
204    let chunks = data.chunks_exact(4);
205    let remainder = chunks.remainder();
206
207    for chunk in chunks {
208        // chunks_exact(4) guarantees chunk has exactly 4 elements.
209        #[allow(
210            clippy::indexing_slicing,
211            reason = "chunks_exact(4) guarantees chunk.len() == 4; indices 0-3 are always valid"
212        )]
213        {
214            let bits0 = decode_char(chunk[0], 0, url_safe)?;
215            let bits1 = decode_char(chunk[1], 1, url_safe)?;
216            let bits2 = decode_char(chunk[2], 2, url_safe)?;
217            let bits3 = decode_char(chunk[3], 3, url_safe)?;
218            let word = (u32::from(bits0) << 18)
219                | (u32::from(bits1) << 12)
220                | (u32::from(bits2) << 6)
221                | u32::from(bits3);
222            // Each shift extracts an 8-bit field from a 24-bit word; the
223            // truncating cast to u8 is the intended operation.
224            #[allow(
225                clippy::as_conversions,
226                reason = "extracting 8-bit fields from a 24-bit Base64 word; truncation is the correct semantic"
227            )]
228            {
229                out.push((word >> 16) as u8);
230                out.push((word >> 8) as u8);
231                out.push(word as u8);
232            }
233        }
234    }
235
236    match remainder.len() {
237        2 => {
238            // remainder.len() == 2: indices 0 and 1 are valid.
239            #[allow(
240                clippy::indexing_slicing,
241                reason = "remainder.len() == 2 is proven by the match arm; indices 0 and 1 are always valid"
242            )]
243            {
244                let bits0 = decode_char(remainder[0], 0, url_safe)?;
245                let bits1 = decode_char(remainder[1], 1, url_safe)?;
246                let word = (u32::from(bits0) << 18) | (u32::from(bits1) << 12);
247                // Extracting the top 8 bits of the 24-bit word.
248                #[allow(
249                    clippy::as_conversions,
250                    reason = "extracting 8-bit field from 24-bit Base64 word; truncation is the correct semantic"
251                )]
252                out.push((word >> 16) as u8);
253            }
254        }
255        3 => {
256            // remainder.len() == 3: indices 0, 1, and 2 are valid.
257            #[allow(
258                clippy::indexing_slicing,
259                reason = "remainder.len() == 3 is proven by the match arm; indices 0, 1, and 2 are always valid"
260            )]
261            {
262                let bits0 = decode_char(remainder[0], 0, url_safe)?;
263                let bits1 = decode_char(remainder[1], 1, url_safe)?;
264                let bits2 = decode_char(remainder[2], 2, url_safe)?;
265                let word =
266                    (u32::from(bits0) << 18) | (u32::from(bits1) << 12) | (u32::from(bits2) << 6);
267                // Extracting 8-bit fields from a 24-bit Base64 word.
268                #[allow(
269                    clippy::as_conversions,
270                    reason = "extracting 8-bit fields from 24-bit Base64 word; truncation is the correct semantic"
271                )]
272                {
273                    out.push((word >> 16) as u8);
274                    out.push((word >> 8) as u8);
275                }
276            }
277        }
278        _ => {}
279    }
280
281    Ok(out)
282}
283
284#[inline]
285fn decode_char(byte: u8, index: usize, url_safe: bool) -> Result<u8, DecodeError> {
286    match byte {
287        // Match arm guards prove the subtraction cannot underflow:
288        // b'A'..=b'Z' guarantees byte >= b'A', so byte - b'A' is in 0..=25.
289        // b'a'..=b'z' guarantees byte >= b'a', so byte - b'a' is in 0..=25;
290        //   adding 26 gives 26..=51, which fits in u8.
291        // b'0'..=b'9' guarantees byte >= b'0', so byte - b'0' is in 0..=9;
292        //   adding 52 gives 52..=61, which fits in u8.
293        #[allow(
294            clippy::arithmetic_side_effects,
295            reason = "match arm guards prove byte >= b'A'; subtraction cannot underflow; result fits in u8"
296        )]
297        b'A'..=b'Z' => Ok(byte - b'A'),
298        #[allow(
299            clippy::arithmetic_side_effects,
300            reason = "match arm guards prove byte >= b'a' and byte - b'a' <= 25; adding 26 gives at most 51, fitting in u8"
301        )]
302        b'a'..=b'z' => Ok(byte - b'a' + 26),
303        #[allow(
304            clippy::arithmetic_side_effects,
305            reason = "match arm guards prove byte >= b'0' and byte - b'0' <= 9; adding 52 gives at most 61, fitting in u8"
306        )]
307        b'0'..=b'9' => Ok(byte - b'0' + 52),
308        b'+' if !url_safe => Ok(62),
309        b'/' if !url_safe => Ok(63),
310        b'-' if url_safe => Ok(62),
311        b'_' if url_safe => Ok(63),
312        _ => Err(DecodeError::InvalidChar { index, byte }),
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 4648 §10 test vectors as `(plain, encoded)` pairs.
    const RFC4648_VECTORS: [(&str, &str); 7] = [
        ("", ""),
        ("f", "Zg=="),
        ("fo", "Zm8="),
        ("foo", "Zm9v"),
        ("foob", "Zm9vYg=="),
        ("fooba", "Zm9vYmE="),
        ("foobar", "Zm9vYmFy"),
    ];

    // Results are compared as `Ok(..)`/`Err(..)` rather than through `.ok()`
    // so that a failing assertion prints the actual `DecodeError`.

    #[test]
    fn rfc4648_test_vectors() {
        for (input, expected) in RFC4648_VECTORS {
            assert_eq!(encode(input), expected, "encode({input:?})");
            assert_eq!(
                decode(expected),
                Ok(input.as_bytes().to_vec()),
                "decode({expected:?})"
            );
        }
    }

    #[test]
    fn encode_empty() {
        assert_eq!(encode(b""), "");
    }

    #[test]
    fn encode_hello_world() {
        assert_eq!(encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
    }

    #[test]
    fn decode_hello_world() {
        assert_eq!(
            decode("SGVsbG8sIFdvcmxkIQ=="),
            Ok(b"Hello, World!".to_vec())
        );
    }

    #[test]
    fn decode_without_padding() {
        // The decoder accepts input whose trailing `=` padding was omitted.
        assert_eq!(decode("SGVsbG8sIFdvcmxkIQ"), Ok(b"Hello, World!".to_vec()));
    }

    #[test]
    fn url_safe_encode() {
        // 0xFF 0xFE 0xFD splits into the 6-bit values 63, 63, 59, 61, so the
        // first two symbols are exactly the ones that differ between the
        // alphabets.
        let input = [0xFF_u8, 0xFE, 0xFD];
        assert_eq!(encode(input), "//79");
        assert_eq!(encode_url_safe_no_pad(input), "__79");

        // A partial final group must not be padded by the no-pad encoder.
        let url = encode_url_safe_no_pad([0xFB_u8, 0xFF]);
        assert_eq!(url, "-_8");
        assert!(!url.contains('+'));
        assert!(!url.contains('/'));
        assert!(!url.contains('='));
    }

    #[test]
    fn url_safe_encode_with_padding() {
        assert_eq!(encode_url_safe([0xFB_u8, 0xFF]), "-_8=");
        assert_eq!(encode_url_safe([0xFB_u8]), "-w==");
        assert_eq!(encode_url_safe(b"foo"), "Zm9v");
        // The URL-safe decoder tolerates the padding this encoder produces.
        assert_eq!(decode_url_safe_no_pad("-_8="), Ok(vec![0xFB, 0xFF]));
    }

    #[test]
    fn url_safe_roundtrip() {
        let input = b"Hello, World! This is a test of URL-safe base64.";
        let encoded = encode_url_safe_no_pad(input);
        assert_eq!(decode_url_safe_no_pad(&encoded), Ok(input.to_vec()));
    }

    #[test]
    fn alphabets_are_not_interchangeable() {
        assert_eq!(
            decode("-_8="),
            Err(DecodeError::InvalidChar { index: 0, byte: b'-' })
        );
        assert_eq!(
            decode_url_safe_no_pad("+/8"),
            Err(DecodeError::InvalidChar { index: 0, byte: b'+' })
        );
    }

    #[test]
    fn decode_invalid_char() {
        assert_eq!(
            decode("!!!!"),
            Err(DecodeError::InvalidChar { index: 0, byte: b'!' })
        );
    }

    #[test]
    fn decode_invalid_length() {
        // A single symbol carries only 6 bits (length mod 4 == 1).
        assert_eq!(decode("A"), Err(DecodeError::InvalidLength));
    }

    #[test]
    fn roundtrip_all_byte_values() {
        let input: Vec<u8> = (0..=u8::MAX).collect();
        let encoded = encode(&input);
        assert_eq!(decode(&encoded), Ok(input));
    }

    #[test]
    fn roundtrip_various_lengths() {
        for len in 0..=64_u8 {
            let input: Vec<u8> = (0..len).collect();
            let encoded = encode(&input);
            assert_eq!(
                decode(&encoded),
                Ok(input),
                "roundtrip failed for len={len}"
            );
        }
    }

    #[test]
    fn decode_with_whitespace() {
        let encoded = "SGVs\nbG8s\nIFdv\ncmxk\nIQ==";
        assert_eq!(decode(encoded), Ok(b"Hello, World!".to_vec()));
    }
}
427}