1use num_bigint::BigUint;
4use num_traits::Num;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
/// A parsed key-exchange control message.
///
/// Each variant corresponds to one textual wire form produced by
/// [`KeyExchangeMessage::render`] and accepted by [`parse_key_exchange`]; every
/// wire form ends with `.` and all numbers/bytes travel as lower-case hexadecimal.
///
/// NOTE(review): the derived `Debug` prints `ClearKey::key` verbatim, so logging a
/// value of this type may expose key material — confirm that is acceptable.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum KeyExchangeMessage {
    /// Wire form `RSA key exchange request=<n>,<e>.` — presumably the sender's RSA
    /// public modulus `n` and public exponent `e` (TODO confirm against callers).
    RsaRequest { n: BigUint, e: BigUint },
    /// Wire form `RSA key exchange reply=<c>.` — presumably a value encrypted under
    /// the requester's `(n, e)`; verify against callers.
    RsaReply { c: BigUint },
    /// Wire form `DH key exchange request=<g>,<p>,<a_pub>.` — presumably
    /// Diffie-Hellman parameters plus the initiator's public value.
    DhRequest {
        /// Presumably the group generator.
        g: BigUint,
        /// Presumably the modulus. NOTE(review): the parser does not reject
        /// `p == 0`; if this is used as a `mod_pow` modulus, zero would panic —
        /// confirm callers validate it.
        p: BigUint,
        /// Presumably the initiator's public value.
        a_pub: BigUint,
    },
    /// Wire form `DH key exchange reply=<b_pub>.` — presumably the responder's
    /// public value.
    DhReply { b_pub: BigUint },
    /// Wire form `Clear key exchange=<hex bytes>.` — raw key bytes sent without
    /// encryption, hex-encoded on the wire.
    ClearKey { key: Vec<u8> },
}
29
30#[derive(Debug, Error)]
32pub enum KeyExchangeError {
33 #[error("key exchange body is not utf-8")]
35 NonUtf8,
36 #[error("malformed key exchange body")]
38 Malformed,
39 #[error("key exchange numbers must be lower-case hexadecimal")]
41 NotLowerHex,
42 #[error("invalid hexadecimal payload")]
44 HexDecode,
45}
46
47pub fn parse_key_exchange(body: &[u8]) -> Result<Option<KeyExchangeMessage>, KeyExchangeError> {
51 let Ok(text) = std::str::from_utf8(body) else {
52 return Ok(None);
53 };
54
55 if let Some(rest) = text.strip_prefix("RSA key exchange request=") {
56 let rest = rest.strip_suffix('.').ok_or(KeyExchangeError::Malformed)?;
57 let (n_hex, e_hex) = rest.split_once(',').ok_or(KeyExchangeError::Malformed)?;
58 let n = parse_biguint_hex(n_hex)?;
59 let e = parse_biguint_hex(e_hex)?;
60 return Ok(Some(KeyExchangeMessage::RsaRequest { n, e }));
61 }
62 if let Some(rest) = text.strip_prefix("RSA key exchange reply=") {
63 let rest = rest.strip_suffix('.').ok_or(KeyExchangeError::Malformed)?;
64 let c = parse_biguint_hex(rest)?;
65 return Ok(Some(KeyExchangeMessage::RsaReply { c }));
66 }
67 if let Some(rest) = text.strip_prefix("DH key exchange request=") {
68 let rest = rest.strip_suffix('.').ok_or(KeyExchangeError::Malformed)?;
69 let mut parts = rest.split(',');
70 let g = parse_biguint_hex(parts.next().ok_or(KeyExchangeError::Malformed)?)?;
71 let p = parse_biguint_hex(parts.next().ok_or(KeyExchangeError::Malformed)?)?;
72 let a_pub = parse_biguint_hex(parts.next().ok_or(KeyExchangeError::Malformed)?)?;
73 if parts.next().is_some() {
74 return Err(KeyExchangeError::Malformed);
75 }
76 return Ok(Some(KeyExchangeMessage::DhRequest { g, p, a_pub }));
77 }
78 if let Some(rest) = text.strip_prefix("DH key exchange reply=") {
79 let rest = rest.strip_suffix('.').ok_or(KeyExchangeError::Malformed)?;
80 let b_pub = parse_biguint_hex(rest)?;
81 return Ok(Some(KeyExchangeMessage::DhReply { b_pub }));
82 }
83 if let Some(rest) = text.strip_prefix("Clear key exchange=") {
84 let rest = rest.strip_suffix('.').ok_or(KeyExchangeError::Malformed)?;
85 let key = parse_hex_bytes(rest)?;
86 return Ok(Some(KeyExchangeMessage::ClearKey { key }));
87 }
88
89 Ok(None)
90}
91
92impl KeyExchangeMessage {
93 #[must_use]
95 pub fn render(&self) -> String {
96 match self {
97 Self::RsaRequest { n, e } => {
98 format!(
99 "RSA key exchange request={},{}.",
100 biguint_to_lower_hex(n),
101 biguint_to_lower_hex(e)
102 )
103 }
104 Self::RsaReply { c } => format!("RSA key exchange reply={}.", biguint_to_lower_hex(c)),
105 Self::DhRequest { g, p, a_pub } => format!(
106 "DH key exchange request={},{},{}.",
107 biguint_to_lower_hex(g),
108 biguint_to_lower_hex(p),
109 biguint_to_lower_hex(a_pub)
110 ),
111 Self::DhReply { b_pub } => {
112 format!("DH key exchange reply={}.", biguint_to_lower_hex(b_pub))
113 }
114 Self::ClearKey { key } => format!("Clear key exchange={}.", hex::encode(key)),
115 }
116 }
117}
118
119#[must_use]
121pub fn mod_pow(base: &BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
122 base.modpow(exp, modulus)
123}
124
125pub fn parse_biguint_hex(s: &str) -> Result<BigUint, KeyExchangeError> {
127 if s.is_empty() || !is_lower_hex(s) {
128 return Err(KeyExchangeError::NotLowerHex);
129 }
130 BigUint::from_str_radix(s, 16).map_err(|_| KeyExchangeError::HexDecode)
131}
132
133pub fn parse_hex_bytes(s: &str) -> Result<Vec<u8>, KeyExchangeError> {
135 if s.is_empty() || !s.len().is_multiple_of(2) || !is_lower_hex(s) {
136 return Err(KeyExchangeError::NotLowerHex);
137 }
138 hex::decode(s).map_err(|_| KeyExchangeError::HexDecode)
139}
140
141#[must_use]
143pub fn biguint_to_lower_hex(value: &BigUint) -> String {
144 let mut out = value.to_str_radix(16);
145 if out.is_empty() {
146 out.push('0');
147 }
148 out
149}
150
/// Returns `true` when every byte of `s` is an ASCII digit or a lower-case `a`-`f`.
///
/// An empty string returns `true` (vacuously); callers check emptiness separately.
fn is_lower_hex(s: &str) -> bool {
    s.bytes().all(|byte| matches!(byte, b'0'..=b'9' | b'a'..=b'f'))
}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `body`, panicking unless it is a well-formed control message.
    fn parse_control(body: &[u8]) -> KeyExchangeMessage {
        parse_key_exchange(body).expect("parse").expect("control")
    }

    #[test]
    fn parse_and_render_clear() {
        let parsed = parse_control(b"Clear key exchange=666f6f.");
        assert_eq!(
            parsed,
            KeyExchangeMessage::ClearKey {
                key: b"foo".to_vec()
            }
        );
        assert_eq!(parsed.render(), "Clear key exchange=666f6f.");
    }

    #[test]
    fn parse_rsa_request() {
        match parse_control(b"RSA key exchange request=0f,11.") {
            KeyExchangeMessage::RsaRequest { n, e } => {
                assert_eq!(biguint_to_lower_hex(&n), "f");
                assert_eq!(biguint_to_lower_hex(&e), "11");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rsa_reply_and_dh_messages_round_trip() {
        for body in [
            "RSA key exchange reply=1f.",
            "DH key exchange request=2,17,5.",
            "DH key exchange reply=0.",
        ] {
            assert_eq!(parse_control(body.as_bytes()).render(), body);
        }
    }

    #[test]
    fn dh_request_requires_exactly_three_fields() {
        for body in [
            "DH key exchange request=2,17.",
            "DH key exchange request=2,17,5,1.",
        ] {
            assert!(matches!(
                parse_key_exchange(body.as_bytes()),
                Err(KeyExchangeError::Malformed)
            ));
        }
    }

    #[test]
    fn rejects_non_lower_hex_payloads() {
        assert!(matches!(
            parse_key_exchange(b"RSA key exchange reply=FF."),
            Err(KeyExchangeError::NotLowerHex)
        ));
        assert!(matches!(
            parse_key_exchange(b"Clear key exchange=abc."),
            Err(KeyExchangeError::NotLowerHex)
        ));
        assert!(matches!(
            parse_key_exchange(b"DH key exchange reply=."),
            Err(KeyExchangeError::NotLowerHex)
        ));
    }

    #[test]
    fn missing_terminator_is_malformed() {
        assert!(matches!(
            parse_key_exchange(b"Clear key exchange=666f6f"),
            Err(KeyExchangeError::Malformed)
        ));
    }

    #[test]
    fn rejects_leading_or_trailing_whitespace() {
        assert!(
            parse_key_exchange(b" RSA key exchange reply=ff.")
                .expect("parse")
                .is_none()
        );
        assert!(parse_key_exchange(b"RSA key exchange reply=ff. ").is_err());
    }

    #[test]
    fn ordinary_text_is_not_key_exchange_control() {
        assert_eq!(parse_key_exchange(b"hello there").expect("parse"), None);
    }

    #[test]
    fn non_utf8_body_is_not_key_exchange_control() {
        assert_eq!(parse_key_exchange(&[0xff, 0xfe]).expect("parse"), None);
    }
}