1use alloc::vec::Vec;
26
/// Error type returned by [`decode`]; an alias for the crate-level `Error`.
pub type Error = crate::Error;

/// Number of base64 characters per PEM body line written by [`encode`]
/// (the final body line may be shorter). 64 is the line width RFC 7468
/// prescribes for strictly-encoded output.
const LINE_LEN: usize = 64;
43
44#[must_use]
59pub fn encode(label: &str, der: &[u8]) -> alloc::string::String {
60 use core::fmt::Write;
61 let body = base64_encode(der);
62 let line_count = body.len().div_ceil(LINE_LEN);
64 let mut out =
65 alloc::string::String::with_capacity(body.len() + line_count + 2 * (label.len() + 16));
66 let _ = writeln!(out, "-----BEGIN {label}-----");
67 let mut start = 0;
68 while start < body.len() {
69 let end = (start + LINE_LEN).min(body.len());
70 out.push_str(&body[start..end]);
71 out.push('\n');
72 start = end;
73 }
74 let _ = writeln!(out, "-----END {label}-----");
75 out
76}
77
78pub fn decode(input: &str, expected_label: &str) -> Result<Vec<u8>, Error> {
90 let begin = alloc::format!("-----BEGIN {expected_label}-----");
91 let end = alloc::format!("-----END {expected_label}-----");
92
93 let begin_idx = input.find(&begin).ok_or(Error::Failed)?;
94 let after_begin = &input[begin_idx + begin.len()..];
95 let end_rel = after_begin.find(&end).ok_or(Error::Failed)?;
96 let body = &after_begin[..end_rel];
97
98 let mut stripped = alloc::string::String::with_capacity(body.len());
102 for ch in body.chars() {
103 if !ch.is_ascii_whitespace() {
104 stripped.push(ch);
105 }
106 }
107
108 base64_decode(&stripped).ok_or(Error::Failed)
109}
110
/// The standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as standard, `=`-padded base64.
///
/// Each 3-byte group becomes 4 alphabet characters. A trailing 1- or 2-byte
/// group is zero-extended, and its unused output positions become `=`. The
/// result length is therefore always a multiple of 4; empty input gives `""`.
#[must_use]
fn base64_encode(input: &[u8]) -> alloc::string::String {
    let mut encoded = alloc::string::String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Bytes missing from a short final group are treated as zero.
        let first = group[0];
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);
        let sextets = [
            first >> 2,
            ((first & 0x03) << 4) | (second >> 4),
            ((second & 0x0F) << 2) | (third >> 6),
            third & 0x3F,
        ];
        // A group of n input bytes carries n + 1 meaningful sextets; the
        // remaining output positions are padding.
        let meaningful = group.len() + 1;
        for (pos, &sextet) in sextets.iter().enumerate() {
            if pos < meaningful {
                encoded.push(char::from(BASE64_ALPHABET[usize::from(sextet)]));
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
149
/// Decodes strict, `=`-padded standard base64.
///
/// Returns `None` when any of the following holds:
/// - the length is not a multiple of 4;
/// - a byte outside the alphabet appears, including `=` anywhere other than
///   the last one or two positions;
/// - the final group's unused low bits are non-zero (a non-canonical
///   encoding).
///
/// The empty string decodes to an empty vector.
#[must_use]
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let raw = input.as_bytes();
    if raw.len() % 4 != 0 {
        return None;
    }
    if raw.is_empty() {
        return Some(Vec::new());
    }

    let padding: usize = match raw {
        [.., b'=', b'='] => 2,
        [.., b'='] => 1,
        _ => 0,
    };
    // Indices at or beyond `significant` hold the trailing `=` characters.
    // They can only fall inside the final group.
    let significant = raw.len() - padding;
    let group_count = raw.len() / 4;

    let mut decoded = Vec::with_capacity(group_count * 3);
    for (group_idx, quad) in raw.chunks_exact(4).enumerate() {
        let base = group_idx * 4;
        let mut sextets = [0u8; 4];
        for (offset, (slot, &ch)) in sextets.iter_mut().zip(quad).enumerate() {
            *slot = if base + offset < significant {
                base64_lookup(ch)?
            } else if ch == b'=' {
                0
            } else {
                return None;
            };
        }
        let [s0, s1, s2, s3] = sextets;

        let is_final = group_idx + 1 == group_count;
        if is_final {
            // Bits that the padding says are unused must be zero, otherwise
            // several distinct strings would decode to the same bytes.
            let stray_bits = match padding {
                2 => s1 & 0x0F,
                1 => s2 & 0x03,
                _ => 0,
            };
            if stray_bits != 0 {
                return None;
            }
        }

        let triple = [(s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2), (s2 << 6) | s3];
        let emitted = if is_final { 3 - padding } else { 3 };
        decoded.extend_from_slice(&triple[..emitted]);
    }
    Some(decoded)
}

/// Maps one standard-alphabet base64 character to its 6-bit value.
///
/// Any other byte, including the `=` pad character, yields `None`.
const fn base64_lookup(c: u8) -> Option<u8> {
    match c {
        b'A'..=b'Z' => Some(c - b'A'),
        b'a'..=b'z' => Some(c - b'a' + 26),
        b'0'..=b'9' => Some(c - b'0' + 52),
        b'+' => Some(62),
        b'/' => Some(63),
        _ => None,
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn base64_round_trip_empty() {
        let bytes: &[u8] = &[];
        assert_eq!(base64_encode(bytes), "");
        assert_eq!(base64_decode("").as_deref(), Some(bytes));
    }

    #[test]
    fn base64_round_trip_one_byte() {
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_decode("Zg==").as_deref(), Some(b"f".as_slice()));
    }

    #[test]
    fn base64_round_trip_two_bytes() {
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_decode("Zm8=").as_deref(), Some(b"fo".as_slice()));
    }

    #[test]
    fn base64_round_trip_three_bytes() {
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_decode("Zm9v").as_deref(), Some(b"foo".as_slice()));
    }

    #[test]
    fn base64_rfc4648_test_vectors() {
        for (raw, encoded) in [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ] {
            assert_eq!(base64_encode(raw.as_bytes()), encoded);
            assert_eq!(
                base64_decode(encoded).as_deref(),
                Some(raw.as_bytes()),
                "decode {encoded:?}"
            );
        }
    }

    #[test]
    fn base64_decode_rejects_bad_chars() {
        // Byte outside the alphabet.
        assert!(base64_decode("Zm9*").is_none());
        // Length is not a multiple of four.
        assert!(base64_decode("Zm9").is_none());
        // More padding than any valid group can carry.
        assert!(base64_decode("Z===").is_none());
        assert!(base64_decode("====").is_none());
    }

    #[test]
    fn base64_decode_rejects_padding_before_final_position() {
        // `=` in a non-final group.
        assert!(base64_decode("Zg==Zg==").is_none());
        // `=` followed by a data character inside the final group.
        assert!(base64_decode("Zm=v").is_none());
    }

    #[test]
    fn base64_decode_rejects_non_canonical_pad_bits() {
        // 'h' leaves low bits set that two padding characters declare unused.
        assert!(base64_decode("Zh==").is_none());
        assert!(base64_decode("Zg==").is_some());
        assert!(base64_decode("Zm8=").is_some());
        // '9' leaves low bits set that one padding character declares unused.
        assert!(base64_decode("Zm9=").is_none());
    }

    #[test]
    fn pem_round_trip_short() {
        let der: &[u8] = &[0x30, 0x03, 0x02, 0x01, 0x05];
        let pem = encode("EC PRIVATE KEY", der);
        let recovered = decode(&pem, "EC PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_round_trip_empty_der() {
        let pem = encode("PRIVATE KEY", &[]);
        assert_eq!(
            pem,
            "-----BEGIN PRIVATE KEY-----\n-----END PRIVATE KEY-----\n"
        );
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert!(recovered.is_empty());
    }

    #[test]
    fn pem_round_trip_long_wraps_at_64() {
        let der: alloc::vec::Vec<u8> = (0..100u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        for line in pem.lines() {
            if line.starts_with("---") {
                continue;
            }
            assert!(line.len() <= LINE_LEN, "body line too long: {line:?}");
        }
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_body_of_exactly_one_full_line() {
        // 48 bytes encode to exactly 64 base64 characters: a single body line.
        let der: alloc::vec::Vec<u8> = (0..48u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        assert_eq!(pem.lines().count(), 3);
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_label_must_match() {
        let pem = encode("PRIVATE KEY", b"\x30\x00");
        assert!(matches!(decode(&pem, "PUBLIC KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_rejects_missing_begin() {
        assert!(matches!(
            decode("garbage", "PRIVATE KEY"),
            Err(Error::Failed)
        ));
    }

    #[test]
    fn pem_decode_rejects_missing_end() {
        let bad = "-----BEGIN PRIVATE KEY-----\nABCD\n";
        assert!(matches!(decode(bad, "PRIVATE KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_tolerates_crlf_and_extra_whitespace() {
        // CRLF line endings, leading spaces, a trailing tab and a blank line
        // must all be ignored inside the body.
        let pem = "-----BEGIN PRIVATE KEY-----\r\n  MAMC \t\r\n\r\nAQU=\r\n-----END PRIVATE KEY-----\r\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_encoded_form_is_strict() {
        let der: alloc::vec::Vec<u8> = (0..200u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        assert!(pem.ends_with('\n'));
        assert!(!pem.contains('\r'));
        assert!(pem.starts_with("-----BEGIN PRIVATE KEY-----\n"));
        assert!(pem.contains("\n-----END PRIVATE KEY-----\n"));
    }
}