use alloc::vec::Vec;
/// Failure reported when PEM text cannot be turned back into DER bytes.
///
/// There is a single variant, so callers learn only that decoding failed,
/// not which step rejected the input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// A BEGIN/END marker with the expected label was not found, or the
    /// text between the markers was not valid padded standard base64.
    Failed,
}
/// Maximum number of base64 characters written per body line by `encode`.
const LINE_LEN: usize = 64;

/// Serialises `der` as a PEM block carrying the given `label`.
///
/// The output consists of `-----BEGIN {label}-----`, the padded standard
/// base64 encoding of `der` split into lines of at most `LINE_LEN`
/// characters, and `-----END {label}-----`. Every line, including the last,
/// is terminated by a single `\n`; no `\r` is emitted.
#[must_use]
pub fn encode(label: &str, der: &[u8]) -> alloc::string::String {
    use core::fmt::Write;

    let b64 = base64_encode(der);
    let body_lines = b64.len().div_ceil(LINE_LEN);
    // Exact final size: the body, one '\n' per body line, and the two marker
    // lines ("-----BEGIN " + "-----\n" is 17 bytes, "-----END " + "-----\n"
    // is 15 bytes, i.e. 2 * 16 in total, plus the label twice).
    let mut pem = alloc::string::String::with_capacity(
        b64.len() + body_lines + 2 * (label.len() + 16),
    );

    // Formatting into a `String` cannot fail, so the `fmt::Result` is dropped.
    let _ = writeln!(pem, "-----BEGIN {label}-----");
    for offset in (0..b64.len()).step_by(LINE_LEN) {
        // base64 output is pure ASCII, so every byte offset is a char boundary.
        let stop = core::cmp::min(offset + LINE_LEN, b64.len());
        pem.push_str(&b64[offset..stop]);
        pem.push('\n');
    }
    let _ = writeln!(pem, "-----END {label}-----");
    pem
}
/// Extracts and base64-decodes the body of the first PEM block labelled
/// `expected_label` in `input`.
///
/// The body is the text between the first `-----BEGIN {expected_label}-----`
/// marker and the first `-----END {expected_label}-----` marker that follows
/// it. Anything before or after that span is ignored. Whitespace inside the
/// body (space, tab, CR, LF, vertical tab, form feed) is skipped, so CRLF
/// line endings and arbitrary line widths are accepted.
///
/// # Errors
///
/// Returns [`Error::Failed`] if either marker is missing, or if the body,
/// once whitespace is removed, is not valid padded standard base64.
pub fn decode(input: &str, expected_label: &str) -> Result<Vec<u8>, Error> {
    let begin = alloc::format!("-----BEGIN {expected_label}-----");
    let end = alloc::format!("-----END {expected_label}-----");
    let begin_idx = input.find(&begin).ok_or(Error::Failed)?;
    let after_begin = &input[begin_idx + begin.len()..];
    let end_rel = after_begin.find(&end).ok_or(Error::Failed)?;
    let body = &after_begin[..end_rel];

    // RFC 7468 defines whitespace `W` as SP / HTAB / CR / LF / VT / FF.
    // `char::is_ascii_whitespace` omits VT (U+000B), so it is matched explicitly.
    let mut stripped = alloc::string::String::with_capacity(body.len());
    stripped.extend(
        body.chars()
            .filter(|&ch| !(ch.is_ascii_whitespace() || ch == '\x0B')),
    );
    base64_decode(&stripped).ok_or(Error::Failed)
}
/// The standard base64 alphabet (RFC 4648 §4): `+` and `/` for values 62 and 63.
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as padded standard base64.
///
/// Every 3 input bytes become 4 output characters. A trailing remainder of 1
/// or 2 bytes becomes 2 or 3 characters followed by `==` or `=` respectively.
#[must_use]
fn base64_encode(input: &[u8]) -> alloc::string::String {
    // Maps a 6-bit value to its alphabet character. Every call site below
    // passes a value < 64, so the index is always in bounds.
    let sextet = |v: u8| char::from(BASE64_ALPHABET[usize::from(v)]);

    let mut encoded = alloc::string::String::with_capacity(input.len().div_ceil(3) * 4);
    let groups = input.chunks_exact(3);
    let tail = groups.remainder();

    for group in groups {
        let (a, b, c) = (group[0], group[1], group[2]);
        encoded.push(sextet(a >> 2));
        encoded.push(sextet(((a & 0x03) << 4) | (b >> 4)));
        encoded.push(sextet(((b & 0x0F) << 2) | (c >> 6)));
        encoded.push(sextet(c & 0x3F));
    }

    match *tail {
        [a] => {
            encoded.push(sextet(a >> 2));
            encoded.push(sextet((a & 0x03) << 4));
            encoded.push_str("==");
        }
        [a, b] => {
            encoded.push(sextet(a >> 2));
            encoded.push(sextet(((a & 0x03) << 4) | (b >> 4)));
            encoded.push(sextet((b & 0x0F) << 2));
            encoded.push('=');
        }
        // `chunks_exact(3)` leaves at most 2 bytes; an empty tail needs no padding.
        _ => {}
    }
    encoded
}
/// Decodes padded standard base64 that contains no whitespace.
///
/// Returns `None` unless all of the following hold:
/// - `input` is a whole number of 4-character quads;
/// - every character is from the standard alphabet;
/// - `=` padding appears only as the last one or two characters;
/// - the unused low bits of a padded final quad are zero.
///
/// The last rule rejects non-canonical encodings such as `Zh==`.
#[must_use]
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let data = input.as_bytes();
    if data.len() % 4 != 0 {
        return None;
    }
    if data.is_empty() {
        return Some(Vec::new());
    }

    // Only the final quad may carry padding, so it is split off and handled last.
    let (body, tail) = data.split_at(data.len() - 4);
    let mut decoded = Vec::with_capacity(data.len() / 4 * 3);

    for quad in body.chunks_exact(4) {
        let s0 = base64_lookup(quad[0])?;
        let s1 = base64_lookup(quad[1])?;
        let s2 = base64_lookup(quad[2])?;
        let s3 = base64_lookup(quad[3])?;
        decoded.extend_from_slice(&[
            (s0 << 2) | (s1 >> 4),
            (s1 << 4) | (s2 >> 2),
            (s2 << 6) | s3,
        ]);
    }

    let s0 = base64_lookup(tail[0])?;
    let s1 = base64_lookup(tail[1])?;
    match (tail[2], tail[3]) {
        (b'=', b'=') => {
            // "xx==" carries one byte; the low 4 bits of s1 are unused and must be 0.
            if s1 & 0x0F != 0 {
                return None;
            }
            decoded.push((s0 << 2) | (s1 >> 4));
        }
        (c2, b'=') => {
            // "xxx=" carries two bytes; the low 2 bits of s2 are unused and must be 0.
            let s2 = base64_lookup(c2)?;
            if s2 & 0x03 != 0 {
                return None;
            }
            decoded.extend_from_slice(&[(s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2)]);
        }
        (c2, c3) => {
            // Unpadded final quad: a stray '=' here is rejected by the lookup.
            let s2 = base64_lookup(c2)?;
            let s3 = base64_lookup(c3)?;
            decoded.extend_from_slice(&[
                (s0 << 2) | (s1 >> 4),
                (s1 << 4) | (s2 >> 2),
                (s2 << 6) | s3,
            ]);
        }
    }
    Some(decoded)
}

/// Returns the 6-bit value of a standard-alphabet base64 character, or
/// `None` for any other byte (including the padding character `=`).
const fn base64_lookup(c: u8) -> Option<u8> {
    match c {
        b'A'..=b'Z' => Some(c - b'A'),
        b'a'..=b'z' => Some(c - b'a' + 26),
        b'0'..=b'9' => Some(c - b'0' + 52),
        b'+' => Some(62),
        b'/' => Some(63),
        _ => None,
    }
}
// Unit tests for the base64 codec and the PEM encode/decode wrappers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn base64_round_trip_empty() {
        let bytes: &[u8] = &[];
        assert_eq!(base64_encode(bytes), "");
        assert_eq!(base64_decode("").as_deref(), Some(bytes));
    }

    #[test]
    fn base64_round_trip_one_byte() {
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_decode("Zg==").as_deref(), Some(b"f".as_slice()));
    }

    #[test]
    fn base64_round_trip_two_bytes() {
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_decode("Zm8=").as_deref(), Some(b"fo".as_slice()));
    }

    #[test]
    fn base64_round_trip_three_bytes() {
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_decode("Zm9v").as_deref(), Some(b"foo".as_slice()));
    }

    #[test]
    fn base64_rfc4648_test_vectors() {
        for (raw, encoded) in [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ] {
            assert_eq!(base64_encode(raw.as_bytes()), encoded);
            assert_eq!(
                base64_decode(encoded).as_deref(),
                Some(raw.as_bytes()),
                "decode {encoded:?}"
            );
        }
    }

    #[test]
    fn base64_decode_rejects_bad_chars() {
        assert!(base64_decode("Zm9*").is_none());
        assert!(base64_decode("Zm9").is_none());
        assert!(base64_decode("Z===").is_none());
        assert!(base64_decode("====").is_none());
    }

    #[test]
    fn base64_decode_rejects_non_canonical_pad_bits() {
        assert!(base64_decode("Zh==").is_none());
        assert!(base64_decode("Zg==").is_some());
        assert!(base64_decode("Zm8=").is_some());
        assert!(base64_decode("Zm9=").is_none());
    }

    #[test]
    fn pem_round_trip_short() {
        let der: &[u8] = &[0x30, 0x03, 0x02, 0x01, 0x05];
        let pem = encode("EC PRIVATE KEY", der);
        let recovered = decode(&pem, "EC PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_round_trip_empty_der() {
        let pem = encode("PRIVATE KEY", &[]);
        assert_eq!(
            pem,
            "-----BEGIN PRIVATE KEY-----\n-----END PRIVATE KEY-----\n"
        );
        assert_eq!(decode(&pem, "PRIVATE KEY"), Ok(alloc::vec::Vec::new()));
    }

    #[test]
    fn pem_round_trip_long_wraps_at_64() {
        let der: alloc::vec::Vec<u8> = (0..100u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        for line in pem.lines() {
            if line.starts_with("---") {
                continue;
            }
            assert!(line.len() <= LINE_LEN, "body line too long: {line:?}");
        }
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_label_must_match() {
        let pem = encode("PRIVATE KEY", b"\x30\x00");
        assert!(matches!(decode(&pem, "PUBLIC KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_rejects_missing_begin() {
        assert!(matches!(
            decode("garbage", "PRIVATE KEY"),
            Err(Error::Failed)
        ));
    }

    #[test]
    fn pem_decode_rejects_missing_end() {
        let bad = "-----BEGIN PRIVATE KEY-----\nABCD\n";
        assert!(matches!(decode(bad, "PRIVATE KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_tolerates_crlf_and_extra_whitespace() {
        let pem = "-----BEGIN PRIVATE KEY-----\r\n\
                   MAMC\r\n\
                   AQU=\r\n\
                   -----END PRIVATE KEY-----\r\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_decode_tolerates_spaces_and_tabs() {
        let pem = "-----BEGIN PRIVATE KEY-----\n MAMC\t\n  AQU= \n-----END PRIVATE KEY-----\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_decode_ignores_text_outside_markers() {
        let pem = "Subject: test\n\
                   -----BEGIN PRIVATE KEY-----\n\
                   MAMCAQU=\n\
                   -----END PRIVATE KEY-----\n\
                   trailing notes\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_encoded_form_is_strict() {
        let der: alloc::vec::Vec<u8> = (0..200u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        assert!(pem.ends_with('\n'));
        assert!(!pem.contains('\r'));
        assert!(pem.starts_with("-----BEGIN PRIVATE KEY-----\n"));
        assert!(pem.contains("\n-----END PRIVATE KEY-----\n"));
    }
}