Skip to main content

reddb_crypto/
key.rs

//! Encryption-key parsing — a mandatory encryption parameter, homed here
//! per #1053 / ADR 0054 (carried forward from the retired RDEP
//! envelope).
//!
//! Accepts a 32-byte AES-256 key as either 64 hex chars or
//! (un)padded standard base64, tolerating surrounding whitespace
//! (e.g. a key file's trailing newline, which `kubectl create secret
//! --from-file` preserves verbatim). Reading the key from the environment
//! stays in `reddb-server` because it layers a server-specific
//! file-fallback convention on top of this parser.
11
12/// Parse a 32-byte AES key from a string — accepts hex (64 chars) or
13/// unpadded/padded standard base64 (43 or 44 chars). Tolerates
14/// leading/trailing whitespace including newlines.
15pub fn parse_key(raw: &str) -> Result<[u8; 32], String> {
16    let trimmed = raw.trim();
17    // Hex: exactly 64 hex digits.
18    if trimmed.len() == 64 && trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
19        let mut out = [0u8; 32];
20        for (i, byte) in out.iter_mut().enumerate() {
21            *byte = u8::from_str_radix(&trimmed[i * 2..i * 2 + 2], 16)
22                .map_err(|err| format!("invalid hex key byte {i}: {err}"))?;
23        }
24        return Ok(out);
25    }
26    // Base64: standard alphabet, 32 raw bytes → 44 chars padded or
27    // 43 unpadded. A tiny inline decoder avoids pulling a base64
28    // crate just for this.
29    let decoded = decode_base64(trimmed)
30        .map_err(|err| format!("key is neither 64-hex nor base64 (decode error: {err})"))?;
31    if decoded.len() != 32 {
32        return Err(format!(
33            "decoded key is {} bytes; AES-256-GCM requires exactly 32",
34            decoded.len()
35        ));
36    }
37    let mut out = [0u8; 32];
38    out.copy_from_slice(&decoded);
39    Ok(out)
40}
41
/// Decode standard-alphabet (RFC 4648 §4) base64 into raw bytes.
///
/// ASCII whitespace is ignored anywhere in the input, so line-wrapped or
/// newline-terminated text decodes. Padding is optional, but when present
/// it must be a trailing run of at most two `=` (whitespace may follow it).
/// Data after a `=` is rejected rather than silently spliced together.
///
/// Error positions are byte offsets into `s` as given (not into a
/// whitespace-stripped copy), so they point at the offending character.
///
/// # Errors
///
/// Returns a message for an out-of-alphabet character, misplaced or
/// excessive padding, or a data length leaving one dangling character
/// (`len % 4 == 1`), which cannot encode a whole byte.
fn decode_base64(s: &str) -> Result<Vec<u8>, String> {
    // Map one base64 alphabet byte to its 6-bit value.
    fn val(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    // Pass 1: validate the input and collect the 6-bit value of every data
    // character, tracking positions in the original string.
    let mut sextets = Vec::with_capacity(s.len());
    let mut padding = 0usize;
    for (pos, byte) in s.bytes().enumerate() {
        if byte.is_ascii_whitespace() {
            continue;
        }
        if byte == b'=' {
            padding += 1;
            if padding > 2 {
                return Err(format!("invalid base64 padding at {pos}"));
            }
            continue;
        }
        if padding > 0 {
            return Err(format!("base64 data after padding at {pos}"));
        }
        sextets.push(val(byte).ok_or_else(|| format!("invalid base64 char at {pos}"))?);
    }

    // Pass 2: each group of 4 sextets yields 3 bytes; a 2- or 3-sextet tail
    // yields 1 or 2 bytes (the low bits of its last sextet are discarded).
    let quads = sextets.chunks_exact(4);
    let tail = quads.remainder();
    let mut out = Vec::with_capacity(sextets.len() * 3 / 4);
    for q in quads {
        out.push((q[0] << 2) | (q[1] >> 4));
        out.push(((q[1] & 0x0F) << 4) | (q[2] >> 2));
        out.push(((q[2] & 0x03) << 6) | q[3]);
    }
    match *tail {
        [] => {}
        [a, b] => out.push((a << 2) | (b >> 4)),
        [a, b, c] => {
            out.push((a << 2) | (b >> 4));
            out.push(((b & 0x0F) << 4) | (c >> 2));
        }
        _ => return Err(format!("invalid base64 length remainder {}", tail.len())),
    }
    Ok(out)
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// 64-hex-digit encoding of the bytes 0x01..=0x20, shared by the hex tests.
    const HEX_KEY: &str = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";

    /// Minimal standard-alphabet base64 encoder for test fixtures, so the
    /// tests need no base64 crate. With `pad`, output is `=`-padded to a
    /// multiple of 4 characters.
    fn encode_base64(raw: &[u8], pad: bool) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = String::new();
        for chunk in raw.chunks(3) {
            // Pack up to 3 bytes big-endian into the low 24 bits.
            let n = chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &b)| acc | (u32::from(b) << (16 - 8 * i)));
            // A chunk of k bytes needs k + 1 sextets to carry its bits.
            for j in 0..=chunk.len() {
                out.push(ALPHABET[((n >> (18 - 6 * j)) & 0x3F) as usize] as char);
            }
            if pad {
                for _ in chunk.len()..3 {
                    out.push('=');
                }
            }
        }
        out
    }

    #[test]
    fn encode_base64_fixture_matches_rfc4648_vectors() {
        assert_eq!(encode_base64(b"Man", true), "TWFu");
        assert_eq!(encode_base64(b"Ma", true), "TWE=");
        assert_eq!(encode_base64(b"M", true), "TQ==");
        assert_eq!(encode_base64(b"M", false), "TQ");
    }

    #[test]
    fn parse_key_accepts_hex() {
        let key = parse_key(HEX_KEY).unwrap();
        assert_eq!(key[0], 0x01);
        assert_eq!(key[31], 0x20);
    }

    #[test]
    fn parse_key_accepts_uppercase_hex() {
        let upper = HEX_KEY.to_ascii_uppercase();
        assert_eq!(parse_key(&upper).unwrap(), parse_key(HEX_KEY).unwrap());
    }

    #[test]
    fn parse_key_accepts_hex_with_whitespace() {
        let hex = format!("  {HEX_KEY}\n");
        assert!(parse_key(&hex).is_ok());
    }

    #[test]
    fn parse_key_rejects_wrong_length() {
        assert!(parse_key("ab").is_err());
        assert!(parse_key("zz".repeat(32).as_str()).is_err()); // 64 chars but not hex
    }

    #[test]
    fn parse_key_accepts_base64() {
        // Unpadded form: 32 bytes -> 43 characters.
        let encoded = encode_base64(&[0xAB; 32], false);
        assert_eq!(encoded.len(), 43);
        assert_eq!(parse_key(&encoded).unwrap(), [0xABu8; 32]);
    }

    #[test]
    fn parse_key_accepts_padded_base64_with_trailing_newline() {
        // Padded form: 32 bytes -> 44 characters.
        let encoded = encode_base64(&[0xAB; 32], true);
        assert_eq!(encoded.len(), 44);
        assert_eq!(parse_key(&format!("{encoded}\n")).unwrap(), [0xABu8; 32]);
    }

    #[test]
    fn parse_key_hex_and_base64_agree() {
        let from_hex = parse_key(HEX_KEY).unwrap();
        let from_b64 = parse_key(&encode_base64(&from_hex, true)).unwrap();
        assert_eq!(from_hex, from_b64);
    }

    #[test]
    fn parse_key_rejects_base64_of_wrong_size() {
        let err = parse_key(&encode_base64(&[0x11; 16], true)).unwrap_err();
        assert!(err.contains("decoded key is 16 bytes"), "{err}");
    }

    #[test]
    fn decode_base64_accepts_symbols_whitespace_and_padding() {
        assert_eq!(decode_base64(" +/+/==\n").unwrap(), vec![251, 255, 191]);
    }

    #[test]
    fn decode_base64_rejects_invalid_chars_by_position() {
        for input in [
            "!AAA", "A!AA", "AA!A", "AAA!", "!A", "A!", "!AA", "A!A", "AA!",
        ] {
            assert!(decode_base64(input).is_err(), "{input}");
        }
    }

    #[test]
    fn decode_base64_rejects_single_char_remainder() {
        let err = decode_base64("A").unwrap_err();
        assert!(err.contains("invalid base64 length remainder 1"));
    }
}