1pub fn parse_key(raw: &str) -> Result<[u8; 32], String> {
16 let trimmed = raw.trim();
17 if trimmed.len() == 64 && trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
19 let mut out = [0u8; 32];
20 for (i, byte) in out.iter_mut().enumerate() {
21 *byte = u8::from_str_radix(&trimmed[i * 2..i * 2 + 2], 16)
22 .map_err(|err| format!("invalid hex key byte {i}: {err}"))?;
23 }
24 return Ok(out);
25 }
26 let decoded = decode_base64(trimmed)
30 .map_err(|err| format!("key is neither 64-hex nor base64 (decode error: {err})"))?;
31 if decoded.len() != 32 {
32 return Err(format!(
33 "decoded key is {} bytes; AES-256-GCM requires exactly 32",
34 decoded.len()
35 ));
36 }
37 let mut out = [0u8; 32];
38 out.copy_from_slice(&decoded);
39 Ok(out)
40}
41
/// Decodes standard-alphabet base64 (`A-Z`, `a-z`, `0-9`, `+`, `/`).
///
/// ASCII whitespace anywhere in the input is ignored. Padding is optional.
/// When present it must come last: at most two `=` characters, followed by
/// nothing but whitespace. A final group of 2 or 3 characters is decoded as a
/// partial group, and a lone trailing character is rejected. Unused low bits
/// in a partial group are not checked.
///
/// # Errors
///
/// Returns a message for:
/// - a character outside the alphabet;
/// - misplaced or excess padding;
/// - a final group of one character.
///
/// Positions in the message are byte offsets into `s`.
fn decode_base64(s: &str) -> Result<Vec<u8>, String> {
    fn val(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    // Pass 1: validate every significant character and translate it to its
    // 6-bit value. Offsets are tracked against the caller's string, not the
    // whitespace-stripped stream, so error positions point at the real input.
    let mut sextets = Vec::with_capacity(s.len());
    let mut padding = 0usize;
    for (pos, c) in s.bytes().enumerate() {
        if c.is_ascii_whitespace() {
            continue;
        }
        if c == b'=' {
            padding += 1;
            if padding > 2 {
                return Err(format!("invalid base64 padding at {pos}"));
            }
            continue;
        }
        if padding > 0 {
            // Data after `=` would otherwise be spliced onto the preceding
            // group and silently decode to the wrong bytes.
            return Err(format!("invalid base64 data after padding at {pos}"));
        }
        sextets.push(val(c).ok_or_else(|| format!("invalid base64 char at {pos}"))?);
    }

    // Pass 2: every full group of four sextets yields three bytes.
    let groups = sextets.chunks_exact(4);
    let tail = groups.remainder();
    let mut out = Vec::with_capacity(sextets.len() / 4 * 3 + tail.len());
    for g in groups {
        out.push((g[0] << 2) | (g[1] >> 4));
        out.push(((g[1] & 0x0F) << 4) | (g[2] >> 2));
        out.push(((g[2] & 0x03) << 6) | g[3]);
    }

    // An unpadded tail of 2 or 3 sextets carries 1 or 2 bytes. A single
    // leftover sextet cannot encode a whole byte.
    match *tail {
        [] => {}
        [a, b] => out.push((a << 2) | (b >> 4)),
        [a, b, c] => {
            out.push((a << 2) | (b >> 4));
            out.push(((b & 0x0F) << 4) | (c >> 2));
        }
        _ => return Err(format!("invalid base64 length remainder {}", tail.len())),
    }
    Ok(out)
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `raw` with the standard base64 alphabet and no padding, so the
    /// decoder is exercised on fixtures it did not produce itself.
    fn encode_base64(raw: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = String::with_capacity((raw.len() * 4 + 2) / 3);
        for chunk in raw.chunks(3) {
            // Pack up to three bytes big-endian into the low 24 bits of `n`;
            // bytes missing from a short final chunk are zero.
            let n = chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (j, &b)| acc | (u32::from(b) << (16 - 8 * j)));
            // A chunk of k bytes needs k + 1 sextets to carry all its bits.
            for s in 0..=chunk.len() {
                out.push(ALPHABET[((n >> (18 - 6 * s)) & 0x3F) as usize] as char);
            }
        }
        out
    }

    #[test]
    fn parse_key_accepts_hex() {
        let hex = "0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20";
        let key = parse_key(hex).unwrap();
        assert_eq!(key[0], 0x01);
        assert_eq!(key[31], 0x20);
        // Upper-case digits decode to the same key.
        assert_eq!(parse_key(&hex.to_uppercase()).unwrap(), key);
    }

    #[test]
    fn parse_key_accepts_hex_with_whitespace() {
        let hex = " 0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20\n";
        let key = parse_key(hex).unwrap();
        assert_eq!(key[0], 0x01);
        assert_eq!(key[31], 0x20);
    }

    #[test]
    fn parse_key_rejects_wrong_length() {
        assert!(parse_key("ab").is_err());
        // 64 non-hex characters fall through to base64 and decode to 48 bytes.
        assert!(parse_key("zz".repeat(32).as_str()).is_err());
    }

    #[test]
    fn parse_key_rejects_base64_of_wrong_length() {
        let err = parse_key(&encode_base64(&[7u8; 31])).unwrap_err();
        assert!(err.contains("decoded key is 31 bytes"), "{err}");
    }

    #[test]
    fn parse_key_accepts_base64() {
        let key = parse_key(&encode_base64(&[0xAB_u8; 32])).unwrap();
        assert_eq!(key, [0xABu8; 32]);
    }

    #[test]
    fn decode_base64_round_trips_every_remainder() {
        for len in 0..=9usize {
            let data: Vec<u8> = (0..len)
                .map(|i| (i as u8).wrapping_mul(37).wrapping_add(11))
                .collect();
            let encoded = encode_base64(&data);
            assert_eq!(decode_base64(&encoded).unwrap(), data, "unpadded len {len}");
            let padding = "=".repeat((4 - encoded.len() % 4) % 4);
            let padded = format!("{encoded}{padding}");
            assert_eq!(decode_base64(&padded).unwrap(), data, "padded len {len}");
        }
    }

    #[test]
    fn decode_base64_accepts_symbols_whitespace_and_padding() {
        assert_eq!(decode_base64(" +/+/==\n").unwrap(), vec![251, 255, 191]);
    }

    #[test]
    fn decode_base64_rejects_invalid_chars_by_position() {
        for input in [
            "!AAA", "A!AA", "AA!A", "AAA!", "!A", "A!", "!AA", "A!A", "AA!",
        ] {
            assert!(decode_base64(input).is_err(), "{input}");
        }
    }

    #[test]
    fn decode_base64_rejects_single_char_remainder() {
        let err = decode_base64("A").unwrap_err();
        assert!(err.contains("invalid base64 length remainder 1"));
    }
}