1use alloc::vec::Vec;
26
/// Error returned by [`decode`] when a PEM document cannot be decoded.
///
/// There is a single variant for every failure (missing markers, malformed
/// base64), so the cause is deliberately not distinguished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// The input could not be decoded.
    Failed,
}
35
/// Maximum number of base64 characters written on each body line by [`encode`].
const LINE_LEN: usize = 64;
39
40#[must_use]
55pub fn encode(label: &str, der: &[u8]) -> alloc::string::String {
56 use core::fmt::Write;
57 let body = base64_encode(der);
58 let line_count = body.len().div_ceil(LINE_LEN);
60 let mut out =
61 alloc::string::String::with_capacity(body.len() + line_count + 2 * (label.len() + 16));
62 let _ = writeln!(out, "-----BEGIN {label}-----");
63 let mut start = 0;
64 while start < body.len() {
65 let end = (start + LINE_LEN).min(body.len());
66 out.push_str(&body[start..end]);
67 out.push('\n');
68 start = end;
69 }
70 let _ = writeln!(out, "-----END {label}-----");
71 out
72}
73
74pub fn decode(input: &str, expected_label: &str) -> Result<Vec<u8>, Error> {
86 let begin = alloc::format!("-----BEGIN {expected_label}-----");
87 let end = alloc::format!("-----END {expected_label}-----");
88
89 let begin_idx = input.find(&begin).ok_or(Error::Failed)?;
90 let after_begin = &input[begin_idx + begin.len()..];
91 let end_rel = after_begin.find(&end).ok_or(Error::Failed)?;
92 let body = &after_begin[..end_rel];
93
94 let mut stripped = alloc::string::String::with_capacity(body.len());
98 for ch in body.chars() {
99 if !ch.is_ascii_whitespace() {
100 stripped.push(ch);
101 }
102 }
103
104 base64_decode(&stripped).ok_or(Error::Failed)
105}
106
/// The standard base64 alphabet, indexed by 6-bit value: `+` is 62 and `/` is 63.
/// [`base64_lookup`] performs the inverse mapping.
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
111
112#[must_use]
115fn base64_encode(input: &[u8]) -> alloc::string::String {
116 let mut out = alloc::string::String::with_capacity(input.len().div_ceil(3) * 4);
117 let mut i = 0;
118 while i + 3 <= input.len() {
119 let b0 = input[i];
120 let b1 = input[i + 1];
121 let b2 = input[i + 2];
122 out.push(BASE64_ALPHABET[(b0 >> 2) as usize] as char);
123 out.push(BASE64_ALPHABET[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
124 out.push(BASE64_ALPHABET[(((b1 & 0x0F) << 2) | (b2 >> 6)) as usize] as char);
125 out.push(BASE64_ALPHABET[(b2 & 0x3F) as usize] as char);
126 i += 3;
127 }
128 let rem = input.len() - i;
129 if rem == 1 {
130 let b0 = input[i];
131 out.push(BASE64_ALPHABET[(b0 >> 2) as usize] as char);
132 out.push(BASE64_ALPHABET[((b0 & 0x03) << 4) as usize] as char);
133 out.push('=');
134 out.push('=');
135 } else if rem == 2 {
136 let b0 = input[i];
137 let b1 = input[i + 1];
138 out.push(BASE64_ALPHABET[(b0 >> 2) as usize] as char);
139 out.push(BASE64_ALPHABET[(((b0 & 0x03) << 4) | (b1 >> 4)) as usize] as char);
140 out.push(BASE64_ALPHABET[((b1 & 0x0F) << 2) as usize] as char);
141 out.push('=');
142 }
143 out
144}
145
146#[must_use]
149fn base64_decode(input: &str) -> Option<Vec<u8>> {
150 let bytes = input.as_bytes();
151 if bytes.len() % 4 != 0 {
152 return None;
153 }
154 if bytes.is_empty() {
155 return Some(Vec::new());
156 }
157
158 let pad = if bytes.ends_with(b"==") {
160 2usize
161 } else {
162 usize::from(bytes.ends_with(b"="))
163 };
164 let body_chars = bytes.len() - pad;
165
166 let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
167 let mut i = 0;
168 while i + 4 <= bytes.len() {
169 let last_group = i + 4 == bytes.len();
172 let v0 = base64_lookup(bytes[i])?;
173 let v1 = base64_lookup(bytes[i + 1])?;
174 let (v2, v3) = if last_group {
175 (
176 if i + 2 < body_chars {
177 base64_lookup(bytes[i + 2])?
178 } else {
179 if bytes[i + 2] != b'=' {
180 return None;
181 }
182 0
183 },
184 if i + 3 < body_chars {
185 base64_lookup(bytes[i + 3])?
186 } else {
187 if bytes[i + 3] != b'=' {
188 return None;
189 }
190 0
191 },
192 )
193 } else {
194 (base64_lookup(bytes[i + 2])?, base64_lookup(bytes[i + 3])?)
195 };
196
197 let b0 = (v0 << 2) | (v1 >> 4);
198 let b1 = (v1 << 4) | (v2 >> 2);
199 let b2 = (v2 << 6) | v3;
200
201 if last_group {
207 if pad == 2 && (v1 & 0x0F) != 0 {
208 return None;
209 }
210 if pad == 1 && (v2 & 0x03) != 0 {
211 return None;
212 }
213 }
214
215 out.push(b0);
216 if !last_group || pad <= 1 {
217 out.push(b1);
218 }
219 if !last_group || pad == 0 {
220 out.push(b2);
221 }
222 i += 4;
223 }
224 Some(out)
225}
226
/// Returns the 6-bit value of base64 alphabet character `c`.
///
/// Returns `None` for any byte outside the alphabet, including the `=` padding byte.
const fn base64_lookup(c: u8) -> Option<u8> {
    match c {
        b'A'..=b'Z' => Some(c - b'A'),
        b'a'..=b'z' => Some(c - b'a' + 26),
        b'0'..=b'9' => Some(c - b'0' + 52),
        b'+' => Some(62),
        b'/' => Some(63),
        _ => None,
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn base64_round_trip_empty() {
        let bytes: &[u8] = &[];
        assert_eq!(base64_encode(bytes), "");
        assert_eq!(base64_decode("").as_deref(), Some(bytes));
    }

    #[test]
    fn base64_round_trip_one_byte() {
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_decode("Zg==").as_deref(), Some(b"f".as_slice()));
    }

    #[test]
    fn base64_round_trip_two_bytes() {
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_decode("Zm8=").as_deref(), Some(b"fo".as_slice()));
    }

    #[test]
    fn base64_round_trip_three_bytes() {
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_decode("Zm9v").as_deref(), Some(b"foo".as_slice()));
    }

    #[test]
    fn base64_rfc4648_test_vectors() {
        // Test vectors from RFC 4648 section 10.
        for (raw, encoded) in [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ] {
            assert_eq!(base64_encode(raw.as_bytes()), encoded);
            assert_eq!(
                base64_decode(encoded).as_deref(),
                Some(raw.as_bytes()),
                "decode {encoded:?}"
            );
        }
    }

    #[test]
    fn base64_decode_rejects_bad_chars() {
        // Character outside the alphabet.
        assert!(base64_decode("Zm9*").is_none());
        // Length is not a multiple of four.
        assert!(base64_decode("Zm9").is_none());
        // More padding than a single quantum can carry.
        assert!(base64_decode("Z===").is_none());
        // Padding only.
        assert!(base64_decode("====").is_none());
    }

    #[test]
    fn base64_decode_rejects_padding_before_final_quantum() {
        assert!(base64_decode("Zg==Zm9v").is_none());
        assert!(base64_decode("Zm8=Zm9v").is_none());
        // '=' in the middle of the final quantum.
        assert!(base64_decode("Z=g=").is_none());
        assert!(base64_decode("Zg=A").is_none());
    }

    #[test]
    fn base64_decode_rejects_non_ascii() {
        // Four bytes long, but the multi-byte character is not in the alphabet.
        assert!(base64_decode("Z\u{e9}=").is_none());
    }

    #[test]
    fn base64_decode_rejects_non_canonical_pad_bits() {
        // 'h' sets a low bit that a single decoded byte cannot carry.
        assert!(base64_decode("Zh==").is_none());
        // Canonical form of "f".
        assert!(base64_decode("Zg==").is_some());
        // Canonical form of "fo".
        assert!(base64_decode("Zm8=").is_some());
        // '9' sets a low bit that two decoded bytes cannot carry.
        assert!(base64_decode("Zm9=").is_none());
    }

    #[test]
    fn pem_round_trip_short() {
        let der: &[u8] = &[0x30, 0x03, 0x02, 0x01, 0x05];
        let pem = encode("EC PRIVATE KEY", der);
        let recovered = decode(&pem, "EC PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_round_trip_long_wraps_at_64() {
        let der: alloc::vec::Vec<u8> = (0..100u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        for line in pem.lines() {
            // Skip the BEGIN/END framing lines.
            if line.starts_with("---") {
                continue;
            }
            assert!(line.len() <= LINE_LEN, "body line too long: {line:?}");
        }
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_round_trip_empty_der() {
        let pem = encode("PRIVATE KEY", &[]);
        assert_eq!(pem, "-----BEGIN PRIVATE KEY-----\n-----END PRIVATE KEY-----\n");
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert!(recovered.is_empty());
    }

    #[test]
    fn pem_body_filling_exactly_one_line() {
        // LINE_LEN / 4 * 3 bytes encode to exactly LINE_LEN characters, so the
        // output has exactly one body line and no empty trailing body line.
        let der = alloc::vec![0xA5u8; LINE_LEN / 4 * 3];
        let pem = encode("PRIVATE KEY", &der);
        let lines: alloc::vec::Vec<&str> = pem.lines().collect();
        assert_eq!(lines.len(), 3);
        assert_eq!(lines[1].len(), LINE_LEN);
        assert_eq!(decode(&pem, "PRIVATE KEY").expect("decode"), der);
    }

    #[test]
    fn pem_label_must_match() {
        let pem = encode("PRIVATE KEY", b"\x30\x00");
        assert!(matches!(decode(&pem, "PUBLIC KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_rejects_missing_begin() {
        assert!(matches!(
            decode("garbage", "PRIVATE KEY"),
            Err(Error::Failed)
        ));
    }

    #[test]
    fn pem_decode_rejects_missing_end() {
        let bad = "-----BEGIN PRIVATE KEY-----\nABCD\n";
        assert!(matches!(decode(bad, "PRIVATE KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_ignores_surrounding_text() {
        let pem = encode("PRIVATE KEY", b"\x30\x00");
        let wrapped = alloc::format!("explanatory text\n{pem}trailing text\n");
        assert_eq!(decode(&wrapped, "PRIVATE KEY").expect("decode"), b"\x30\x00");
    }

    #[test]
    fn pem_decode_tolerates_crlf_and_extra_whitespace() {
        let pem = "-----BEGIN PRIVATE KEY-----\r\n\
                   MAMC\r\n\
                   AQU=\r\n\
                   -----END PRIVATE KEY-----\r\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_encoded_form_is_strict() {
        let der: alloc::vec::Vec<u8> = (0..200u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        assert!(pem.ends_with('\n'));
        // LF line endings only.
        assert!(!pem.contains('\r'));
        assert!(pem.starts_with("-----BEGIN PRIVATE KEY-----\n"));
        assert!(pem.contains("\n-----END PRIVATE KEY-----\n"));
    }
}
386}