Skip to main content

gmcrypto_core/
pem.rs

1//! Hand-rolled PEM (RFC 7468) codec.
2//!
3//! Wraps and unwraps `-----BEGIN <LABEL>-----` ... `-----END <LABEL>-----`
4//! armor around an arbitrary DER blob. Used by [`crate::pkcs8`],
5//! [`crate::spki`], and [`crate::sec1`] for on-disk format support.
6//!
7//! # Posture
8//!
9//! - **Liberal decoder, conservative encoder.** [`decode`] accepts the
10//!   relaxed RFC 7468 production: arbitrary whitespace (including CR,
11//!   LF, tab, space) anywhere inside the body, and either CRLF or LF
12//!   line terminators around the boundary lines. [`encode`] emits the
13//!   strict RFC 1421 production: 64 base-64 characters per line, LF
14//!   terminator, no trailing whitespace.
15//! - **No external dependencies.** The base64 codec is embedded below
16//!   (~80 LOC) per the v0.3 scope's zero-runtime-deps stance (Q7.1).
17//! - **`no_std` + `alloc`.** No file-loading helpers in this module.
18//!
19//! # Failure-mode invariant
20//!
21//! [`decode`] returns `Result<Vec<u8>, Error>` with a single
22//! [`Error::Failed`] variant. Distinguishing "wrong label" from "bad
23//! base64" from "missing END line" is forbidden — see `CLAUDE.md`.
24
25use alloc::vec::Vec;
26
/// The one error [`decode`] can report.
///
/// It deliberately carries no detail. Per the project's failure-mode
/// invariant, a caller must not be able to tell which check failed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Error {
    /// Any decode failure: malformed or missing boundaries, a label
    /// other than the expected one, invalid base64, or a body whose
    /// length is not a multiple of four.
    Failed,
}
35
36/// Strict line length emitted by [`encode`]. RFC 1421 §4.3.2.4 fixes
37/// 64 base-64 characters per line; RFC 7468 §3 keeps the same.
38const LINE_LEN: usize = 64;
39
40/// Encode `der` as a PEM block with the given `label`. Output is the
41/// strict RFC 1421 form: 64 chars per line, LF terminators, no
42/// trailing whitespace.
43///
44/// `label` must be ASCII per RFC 7468 §2 — non-ASCII labels would
45/// round-trip but reject under strict-conformant decoders. The
46/// callers in this crate use fixed labels (`"PRIVATE KEY"`,
47/// `"PUBLIC KEY"`, `"ENCRYPTED PRIVATE KEY"`, `"EC PRIVATE KEY"`),
48/// all ASCII.
49///
50/// # Panics
51///
52/// Never (encoded length is bounded by `4 · der.len() / 3 + small`,
53/// well below the `Vec` allocation ceiling on any realistic input).
54#[must_use]
55pub fn encode(label: &str, der: &[u8]) -> alloc::string::String {
56    use core::fmt::Write;
57    let body = base64_encode(der);
58    // 4-line preamble + (body chunked into 64-char lines) + 4-line postamble.
59    let line_count = body.len().div_ceil(LINE_LEN);
60    let mut out =
61        alloc::string::String::with_capacity(body.len() + line_count + 2 * (label.len() + 16));
62    let _ = writeln!(out, "-----BEGIN {label}-----");
63    let mut start = 0;
64    while start < body.len() {
65        let end = (start + LINE_LEN).min(body.len());
66        out.push_str(&body[start..end]);
67        out.push('\n');
68        start = end;
69    }
70    let _ = writeln!(out, "-----END {label}-----");
71    out
72}
73
74/// Decode a PEM block, returning the raw DER bytes. The block's label
75/// must equal `expected_label` exactly.
76///
77/// Liberal on whitespace (RFC 7468 §3): tabs, spaces, CR, and LF are
78/// all stripped inside the body. The label must match exactly — case
79/// sensitive, no fuzzy-match.
80///
81/// # Errors
82///
83/// Returns [`Error::Failed`] for any malformed input. Single
84/// uninformative variant per the project's failure-mode invariant.
85pub fn decode(input: &str, expected_label: &str) -> Result<Vec<u8>, Error> {
86    let begin = alloc::format!("-----BEGIN {expected_label}-----");
87    let end = alloc::format!("-----END {expected_label}-----");
88
89    let begin_idx = input.find(&begin).ok_or(Error::Failed)?;
90    let after_begin = &input[begin_idx + begin.len()..];
91    let end_rel = after_begin.find(&end).ok_or(Error::Failed)?;
92    let body = &after_begin[..end_rel];
93
94    // Strip whitespace from the body. Anything else (printable
95    // non-base64, non-ASCII) gets fed through to base64_decode, which
96    // rejects it.
97    let mut stripped = alloc::string::String::with_capacity(body.len());
98    for ch in body.chars() {
99        if !ch.is_ascii_whitespace() {
100            stripped.push(ch);
101        }
102    }
103
104    base64_decode(&stripped).ok_or(Error::Failed)
105}
106
// --- base64 codec (RFC 4648 §4, "standard alphabet") ---

/// The 64-symbol standard base64 alphabet, indexed by 6-bit value.
const BASE64_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Render `input` as standard base64, `=`-padded to a multiple of four
/// characters. The result is one unbroken ASCII run; [`encode`] is
/// responsible for wrapping it into lines.
#[must_use]
fn base64_encode(input: &[u8]) -> alloc::string::String {
    // Only the low six bits of `v` select a symbol, so callers may pass
    // shifted bytes without masking first.
    let symbol = |v: u8| char::from(BASE64_ALPHABET[usize::from(v & 0x3F)]);

    let mut encoded = alloc::string::String::with_capacity(input.len().div_ceil(3) * 4);
    let mut triples = input.chunks_exact(3);
    for triple in triples.by_ref() {
        let (a, b, c) = (triple[0], triple[1], triple[2]);
        encoded.push(symbol(a >> 2));
        encoded.push(symbol((a << 4) | (b >> 4)));
        encoded.push(symbol((b << 2) | (c >> 6)));
        encoded.push(symbol(c));
    }

    // A trailing one or two bytes become a padded final quantum.
    match *triples.remainder() {
        [a] => {
            encoded.push(symbol(a >> 2));
            encoded.push(symbol(a << 4));
            encoded.push_str("==");
        }
        [a, b] => {
            encoded.push(symbol(a >> 2));
            encoded.push(symbol((a << 4) | (b >> 4)));
            encoded.push(symbol(b << 2));
            encoded.push('=');
        }
        _ => {}
    }
    encoded
}
145
146/// Decode a base64 string (no whitespace; caller pre-stripped).
147/// Returns `None` for any malformed input.
148#[must_use]
149fn base64_decode(input: &str) -> Option<Vec<u8>> {
150    let bytes = input.as_bytes();
151    if bytes.len() % 4 != 0 {
152        return None;
153    }
154    if bytes.is_empty() {
155        return Some(Vec::new());
156    }
157
158    // Determine pad count from the suffix.
159    let pad = if bytes.ends_with(b"==") {
160        2usize
161    } else {
162        usize::from(bytes.ends_with(b"="))
163    };
164    let body_chars = bytes.len() - pad;
165
166    let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
167    let mut i = 0;
168    while i + 4 <= bytes.len() {
169        // Decode four input characters → three output bytes (minus
170        // pad-driven trim on the final group).
171        let last_group = i + 4 == bytes.len();
172        let v0 = base64_lookup(bytes[i])?;
173        let v1 = base64_lookup(bytes[i + 1])?;
174        let (v2, v3) = if last_group {
175            (
176                if i + 2 < body_chars {
177                    base64_lookup(bytes[i + 2])?
178                } else {
179                    if bytes[i + 2] != b'=' {
180                        return None;
181                    }
182                    0
183                },
184                if i + 3 < body_chars {
185                    base64_lookup(bytes[i + 3])?
186                } else {
187                    if bytes[i + 3] != b'=' {
188                        return None;
189                    }
190                    0
191                },
192            )
193        } else {
194            (base64_lookup(bytes[i + 2])?, base64_lookup(bytes[i + 3])?)
195        };
196
197        let b0 = (v0 << 2) | (v1 >> 4);
198        let b1 = (v1 << 4) | (v2 >> 2);
199        let b2 = (v2 << 6) | v3;
200
201        // Strict-canonical: the bits of the final-group sextets that
202        // would have encoded the dropped output bytes must be zero.
203        // pad=2: low 4 bits of v1 encode part of `b1` (which we drop)
204        // and must be zero. pad=1: low 2 bits of v2 encode part of
205        // `b2` (which we drop) and must be zero.
206        if last_group {
207            if pad == 2 && (v1 & 0x0F) != 0 {
208                return None;
209            }
210            if pad == 1 && (v2 & 0x03) != 0 {
211                return None;
212            }
213        }
214
215        out.push(b0);
216        if !last_group || pad <= 1 {
217            out.push(b1);
218        }
219        if !last_group || pad == 0 {
220            out.push(b2);
221        }
222        i += 4;
223    }
224    Some(out)
225}
226
/// Map one byte of the standard base64 alphabet (RFC 4648 Table 1) to
/// its 6-bit value. Every other byte maps to `None`, including the `=`
/// pad character, which [`base64_decode`] recognizes on its own.
const fn base64_lookup(c: u8) -> Option<u8> {
    match c {
        b'+' => Some(62),
        b'/' => Some(63),
        b'0'..=b'9' => Some(c - b'0' + 52),
        b'A'..=b'Z' => Some(c - b'A'),
        b'a'..=b'z' => Some(c - b'a' + 26),
        _ => None,
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- base64 codec ----------

    #[test]
    fn base64_round_trip_empty() {
        let bytes: &[u8] = &[];
        assert_eq!(base64_encode(bytes), "");
        assert_eq!(base64_decode("").as_deref(), Some(bytes));
    }

    #[test]
    fn base64_round_trip_one_byte() {
        // One input byte → two data sextets + "==" (RFC 4648 §10: "f" = "Zg==").
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_decode("Zg==").as_deref(), Some(b"f".as_slice()));
    }

    #[test]
    fn base64_round_trip_two_bytes() {
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_decode("Zm8=").as_deref(), Some(b"fo".as_slice()));
    }

    #[test]
    fn base64_round_trip_three_bytes() {
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_decode("Zm9v").as_deref(), Some(b"foo".as_slice()));
    }

    #[test]
    fn base64_rfc4648_test_vectors() {
        // RFC 4648 §10.
        for (raw, encoded) in [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ] {
            assert_eq!(base64_encode(raw.as_bytes()), encoded);
            assert_eq!(
                base64_decode(encoded).as_deref(),
                Some(raw.as_bytes()),
                "decode {encoded:?}"
            );
        }
    }

    #[test]
    fn base64_decode_rejects_bad_chars() {
        assert!(base64_decode("Zm9*").is_none()); // '*' not in alphabet
        assert!(base64_decode("Zm9").is_none()); // length not multiple of 4
        assert!(base64_decode("Z===").is_none()); // 3 pads invalid
        assert!(base64_decode("====").is_none()); // all-pad invalid
        assert!(base64_decode("Zg==Zg==").is_none()); // pad before the final quantum
        assert!(base64_decode("Z=g=").is_none()); // pad in the middle of a quantum
    }

    /// Strict canonical: non-zero pad bits in the final quantum reject.
    /// `Zg==` is the canonical encoding of `[0x66]`. `Zh==` decodes to
    /// the same byte `0x66` but leaves a non-zero low nibble in the
    /// second sextet. Accepting it would give one byte string two
    /// distinct encodings.
    #[test]
    fn base64_decode_rejects_non_canonical_pad_bits() {
        // 'Z' = 25, 'h' = 33. v1 = 33 = 0b100001. Low 4 bits = 0b0001 ≠ 0.
        assert!(base64_decode("Zh==").is_none());
        // 'Z' = 25, 'g' = 32. v1 = 32 = 0b100000. Low 4 bits = 0 — accept.
        assert!(base64_decode("Zg==").is_some());
        // 'Z' = 25, 'm' = 38, '8' = 60. v2 = 60 = 0b111100. Low 2 bits = 0 — accept.
        assert!(base64_decode("Zm8=").is_some());
        // Mutate to v2 with non-zero low 2 bits: '9' = 61 = 0b111101. Low 2 = 0b01 ≠ 0.
        assert!(base64_decode("Zm9=").is_none());
    }

    // ---------- PEM ----------

    #[test]
    fn pem_round_trip_short() {
        let der: &[u8] = &[0x30, 0x03, 0x02, 0x01, 0x05];
        let pem = encode("EC PRIVATE KEY", der);
        let recovered = decode(&pem, "EC PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_round_trip_empty() {
        let pem = encode("PUBLIC KEY", &[]);
        assert_eq!(pem, "-----BEGIN PUBLIC KEY-----\n-----END PUBLIC KEY-----\n");
        assert_eq!(decode(&pem, "PUBLIC KEY"), Ok(alloc::vec::Vec::new()));
    }

    #[test]
    fn pem_encoded_form_exact_short() {
        let pem = encode("EC PRIVATE KEY", &[0x30, 0x00]);
        assert_eq!(
            pem,
            "-----BEGIN EC PRIVATE KEY-----\nMAA=\n-----END EC PRIVATE KEY-----\n"
        );
    }

    #[test]
    fn pem_round_trip_long_wraps_at_64() {
        // 100 bytes of DER → 34 quanta → 136 chars of base64 → 3 lines of 64/64/8.
        let der: alloc::vec::Vec<u8> = (0..100u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        let body_lens: alloc::vec::Vec<usize> = pem
            .lines()
            .filter(|line| !line.starts_with("---"))
            .map(str::len)
            .collect();
        assert!(body_lens.iter().all(|&len| len <= LINE_LEN));
        assert_eq!(body_lens, [64, 64, 8]);
        let recovered = decode(&pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, der);
    }

    #[test]
    fn pem_label_must_match() {
        let pem = encode("PRIVATE KEY", b"\x30\x00");
        assert!(matches!(decode(&pem, "PUBLIC KEY"), Err(Error::Failed)));
        // A label that merely contains the expected one must not match.
        let encrypted = encode("ENCRYPTED PRIVATE KEY", b"\x30\x00");
        assert!(matches!(decode(&encrypted, "PRIVATE KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_rejects_missing_begin() {
        assert!(matches!(
            decode("garbage", "PRIVATE KEY"),
            Err(Error::Failed)
        ));
    }

    #[test]
    fn pem_decode_rejects_missing_end() {
        let bad = "-----BEGIN PRIVATE KEY-----\nABCD\n";
        assert!(matches!(decode(bad, "PRIVATE KEY"), Err(Error::Failed)));
    }

    #[test]
    fn pem_decode_tolerates_crlf_and_extra_whitespace() {
        // CRLF terminators plus spaces and tabs inside the body. Written
        // on one line because `\` continuations would strip leading spaces.
        let pem = "-----BEGIN PRIVATE KEY-----\r\n MA MC\t\r\n\tAQU= \r\n-----END PRIVATE KEY-----\r\n";
        let recovered = decode(pem, "PRIVATE KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x03, 0x02, 0x01, 0x05]);
    }

    #[test]
    fn pem_decode_ignores_surrounding_text() {
        let pem = alloc::format!(
            "Subject: test\n{}trailing text\n",
            encode("PUBLIC KEY", &[0x30, 0x00])
        );
        let recovered = decode(&pem, "PUBLIC KEY").expect("decode");
        assert_eq!(recovered, [0x30, 0x00]);
    }

    #[test]
    fn pem_encoded_form_is_strict() {
        let der: alloc::vec::Vec<u8> = (0..200u8).collect();
        let pem = encode("PRIVATE KEY", &der);
        // Strict form: trailing newline; no \r; preamble + body + postamble.
        assert!(pem.ends_with('\n'));
        assert!(!pem.contains('\r'));
        assert!(pem.starts_with("-----BEGIN PRIVATE KEY-----\n"));
        assert!(pem.contains("\n-----END PRIVATE KEY-----\n"));
    }
}