1use bytes::Bytes;
35
36#[derive(Clone, Debug)]
42pub struct PeekResult {
43 pub buffer: Bytes,
44 pub detected: Option<DetectedProtocol>,
45 pub tls: Option<TlsClientHello>,
46}
47
/// Protocol families that peeked bytes can be classified as.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectedProtocol {
    /// A TLS handshake record carrying a ClientHello.
    TlsClientHello,
    /// An HTTP/1.0 or HTTP/1.1 request line.
    Http1,
    /// The cleartext HTTP/2 client connection preface.
    Http2Preface,
    /// Presumably a QUIC Initial packet. No detector in this module emits it yet.
    QuicInitial,
    /// Presumably a DNS message. No detector in this module emits it yet.
    Dns,
    /// Enough bytes were seen, but no detector recognised them.
    Unknown,
}
57
/// Metadata pulled out of a TLS ClientHello.
///
/// `Default` gives an empty value: no SNI and no ALPN protocols.
#[derive(Debug, Default, Clone)]
pub struct TlsClientHello {
    /// The requested server name (SNI), if the client sent one.
    /// `parse_client_hello` stores it ASCII-lowercased.
    pub sni: Option<String>,
    /// ALPN protocol identifiers offered by the client, each as raw bytes.
    pub alpn: Vec<Vec<u8>>,
}
64
/// Maximum number of bytes intended to be peeked before classifying (8 KiB).
///
/// `classify` does not enforce this bound itself. It accepts any slice length.
pub const MAX_PEEK_BYTES: usize = 8192;
70
/// The HTTP/2 client connection preface (RFC 9113 §3.4): the 24 bytes a
/// cleartext HTTP/2 client sends before any frames.
#[cfg(feature = "classify")]
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Request-line prefixes (method name plus its trailing space) that
/// `detect_http1` accepts as the start of an HTTP/1.x request.
///
/// The trailing space keeps, e.g., `GETX` from matching `GET `.
#[cfg(feature = "classify")]
const HTTP1_METHODS: &[&[u8]] = &[
    b"GET ",
    b"POST ",
    b"PUT ",
    b"DELETE ",
    b"HEAD ",
    b"OPTIONS ",
    b"PATCH ",
    b"CONNECT ",
    b"TRACE ",
];

/// Text preceding the minor-version digit in an HTTP/1.x request line
/// (`<method> <target> HTTP/1.<digit>`).
///
/// The leading space separates it from the request target.
#[cfg(feature = "classify")]
const HTTP1_VERSION_PREFIX: &[u8] = b" HTTP/1.";
94
/// Sniffs the first bytes of a connection and guesses which protocol the peer speaks.
///
/// Detectors run in priority order: TLS ClientHello, then the HTTP/2 connection
/// preface, then an HTTP/1.x request line. The first detector that reaches a
/// verdict decides the result:
///
/// * a match sets `detected`. For TLS it also sets `tls` to the parsed ClientHello;
/// * "need more bytes" leaves `detected` as `None`, so the caller can peek again;
/// * if every detector rejects the bytes, `detected` is `DetectedProtocol::Unknown`.
///
/// An empty `buf` is treated as "need more bytes". The returned `buffer` is
/// always an owned copy of `buf`.
#[cfg(feature = "classify")]
#[must_use]
pub fn classify(buf: &[u8]) -> PeekResult {
    let (detected, tls) = if buf.is_empty() { (None, None) } else { run_detectors(buf) };
    PeekResult { buffer: Bytes::copy_from_slice(buf), detected, tls }
}

/// Runs the detectors over a non-empty `buf` and returns `(detected, tls)` for `classify`.
#[cfg(feature = "classify")]
fn run_detectors(buf: &[u8]) -> (Option<DetectedProtocol>, Option<TlsClientHello>) {
    match detect_tls(buf) {
        DetectorOutcome::Match => {
            return (Some(DetectedProtocol::TlsClientHello), Some(parse_client_hello(buf)));
        }
        DetectorOutcome::NeedMore => return (None, None),
        DetectorOutcome::NoMatch => {}
    }

    // The plaintext detectors only classify; they never carry TLS metadata.
    let plaintext: [(fn(&[u8]) -> DetectorOutcome, DetectedProtocol); 2] = [
        (detect_h2_preface, DetectedProtocol::Http2Preface),
        (detect_http1, DetectedProtocol::Http1),
    ];
    for (detect, protocol) in plaintext {
        match detect(buf) {
            DetectorOutcome::Match => return (Some(protocol), None),
            DetectorOutcome::NeedMore => return (None, None),
            DetectorOutcome::NoMatch => {}
        }
    }

    (Some(DetectedProtocol::Unknown), None)
}
151
/// Verdict of a single protocol detector on the bytes peeked so far.
#[cfg(feature = "classify")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DetectorOutcome {
    /// The bytes positively identify the detector's protocol.
    Match,
    /// The bytes fit the protocol so far, but there are too few to decide.
    NeedMore,
    /// The bytes rule the protocol out.
    NoMatch,
}
159
/// Reports whether `buf` begins with a complete TLS ClientHello.
///
/// The first byte must be the TLS handshake record type (`0x16`). The bytes are
/// then fed to a throw-away `rustls::server::Acceptor`:
///
/// * `Match`: the acceptor produced a ClientHello;
/// * `NeedMore`: everything parsed cleanly, but `buf` ends before a full hello;
/// * `NoMatch`: wrong first byte, or rustls rejected the bytes.
#[cfg(feature = "classify")]
fn detect_tls(buf: &[u8]) -> DetectorOutcome {
    if buf.first() != Some(&0x16) {
        return DetectorOutcome::NoMatch;
    }
    let mut acceptor = rustls::server::Acceptor::default();
    let mut input: &[u8] = buf;
    // `read_tls` performs one bounded read (at most 4 KiB per call in current rustls).
    // A single call can therefore leave part of a large ClientHello unread even when
    // all of it is in `buf`. Keep feeding the acceptor until it decides or the bytes run out.
    loop {
        let unread_before = input.len();
        if acceptor.read_tls(&mut input).is_err() {
            return DetectorOutcome::NoMatch;
        }
        match acceptor.accept() {
            Ok(Some(_)) => return DetectorOutcome::Match,
            Err(_) => return DetectorOutcome::NoMatch,
            // Every peeked byte has been consumed (or, defensively, the last read made
            // no progress). The hello is incomplete so far.
            Ok(None) if input.is_empty() || input.len() == unread_before => {
                return DetectorOutcome::NeedMore;
            }
            Ok(None) => {}
        }
    }
}
176
/// Checks `buf` against the HTTP/2 client connection preface (`H2_PREFACE`).
///
/// Only the overlapping prefix is compared:
///
/// * any differing byte rules HTTP/2 out;
/// * a short `buf` that agrees so far needs more bytes;
/// * full-length agreement is a match.
///
/// Bytes after the preface are ignored.
#[cfg(feature = "classify")]
fn detect_h2_preface(buf: &[u8]) -> DetectorOutcome {
    let overlap = buf.len().min(H2_PREFACE.len());
    if buf[..overlap] != H2_PREFACE[..overlap] {
        DetectorOutcome::NoMatch
    } else if overlap < H2_PREFACE.len() {
        DetectorOutcome::NeedMore
    } else {
        DetectorOutcome::Match
    }
}
188
/// Recognises an HTTP/1.0 or HTTP/1.1 request line at the start of `buf`.
///
/// `buf` must begin with one of `HTTP1_METHODS`. A shorter `buf` that could still
/// grow into one needs more bytes.
///
/// After the method, the first `HTTP1_VERSION_PREFIX` in `buf` must not come
/// after the first CRLF, and the next byte must be `0` or `1`. A request line
/// that ends without any version (HTTP/0.9 style) is rejected, as is any other
/// minor version.
#[cfg(feature = "classify")]
fn detect_http1(buf: &[u8]) -> DetectorOutcome {
    if !HTTP1_METHODS.iter().any(|method| buf.starts_with(method)) {
        let may_grow_into_method = HTTP1_METHODS
            .iter()
            .any(|method| buf.len() < method.len() && method.starts_with(buf));
        return if may_grow_into_method {
            DetectorOutcome::NeedMore
        } else {
            DetectorOutcome::NoMatch
        };
    }

    let line_end = memchr::memmem::find(buf, b"\r\n");
    let Some(version_at) = memchr::memmem::find(buf, HTTP1_VERSION_PREFIX) else {
        // No version anywhere yet. A finished line means HTTP/0.9 (rejected);
        // otherwise keep waiting.
        return if line_end.is_some() { DetectorOutcome::NoMatch } else { DetectorOutcome::NeedMore };
    };
    // A version that appears only after the request line ended belongs to a later line.
    if line_end.is_some_and(|end| end < version_at) {
        return DetectorOutcome::NoMatch;
    }
    match buf.get(version_at + HTTP1_VERSION_PREFIX.len()).copied() {
        Some(b'0' | b'1') => DetectorOutcome::Match,
        Some(_) => DetectorOutcome::NoMatch,
        None => DetectorOutcome::NeedMore,
    }
}
225
/// Extracts SNI and ALPN from the TLS ClientHello at the start of `buf`.
///
/// Returns `TlsClientHello::default()` (no SNI, no ALPN) when `buf` does not hold
/// a complete ClientHello that rustls accepts. The SNI is ASCII-lowercased.
#[cfg(feature = "classify")]
fn parse_client_hello(buf: &[u8]) -> TlsClientHello {
    let mut acceptor = rustls::server::Acceptor::default();
    let mut input: &[u8] = buf;
    // `read_tls` performs one bounded read (at most 4 KiB per call in current rustls).
    // Keep feeding the acceptor so a ClientHello larger than one read is still parsed.
    let accepted = loop {
        let unread_before = input.len();
        if acceptor.read_tls(&mut input).is_err() {
            return TlsClientHello::default();
        }
        match acceptor.accept() {
            Ok(Some(accepted)) => break accepted,
            // Incomplete so far. Continue only while unread input remains and the
            // last read made progress.
            Ok(None) if !input.is_empty() && input.len() != unread_before => {}
            Ok(None) | Err(_) => return TlsClientHello::default(),
        }
    };
    let hello = accepted.client_hello();
    let sni = hello.server_name().map(str::to_ascii_lowercase);
    let alpn: Vec<Vec<u8>> = hello
        .alpn()
        .map_or_else(Vec::new, |protocols| protocols.map(<[u8]>::to_vec).collect());
    TlsClientHello { sni, alpn }
}
246
// Unit tests. The type and constant checks always build. The nested `classify`
// module exercises the detectors and builds only with the `classify` feature.
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: `PeekResult` is `Clone`, `Send`, `Sync` and `'static`.
    #[test]
    fn peek_result_is_clone_send_sync_static() {
        fn assert_bounds<T: Clone + Send + Sync + 'static>() {}
        assert_bounds::<PeekResult>();
    }

    // Each variant compares equal to itself and unequal to every other variant.
    #[test]
    fn detected_protocol_variants_are_distinct() {
        let all = [
            DetectedProtocol::TlsClientHello,
            DetectedProtocol::Http1,
            DetectedProtocol::Http2Preface,
            DetectedProtocol::QuicInitial,
            DetectedProtocol::Dns,
            DetectedProtocol::Unknown,
        ];
        for (i, a) in all.iter().enumerate() {
            for (j, b) in all.iter().enumerate() {
                assert_eq!(a == b, i == j);
            }
        }
    }

    #[test]
    fn tls_client_hello_default_is_empty() {
        let h = TlsClientHello::default();
        assert!(h.sni.is_none());
        assert!(h.alpn.is_empty());
    }

    // Pins the public peek limit, so any change to it is a deliberate, visible edit.
    #[test]
    fn max_peek_bytes_is_8k() {
        assert_eq!(MAX_PEEK_BYTES, 8 * 1024);
    }

    #[cfg(feature = "classify")]
    mod classify {
        use super::*;

        // Thin alias for `classify`, used by the plaintext-protocol tests.
        fn classify_short(s: &[u8]) -> PeekResult {
            classify(s)
        }

        // Empty input is "undecided": no protocol, no TLS data, empty buffer copy.
        #[test]
        fn classify_empty_buffer_is_indeterminate() {
            let r = classify(&[]);
            assert!(r.detected.is_none());
            assert!(r.tls.is_none());
            assert!(r.buffer.is_empty());
        }

        #[test]
        fn classify_http1_get_request_line_matches_http1() {
            let r = classify_short(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n");
            assert_eq!(r.detected, Some(DetectedProtocol::Http1));
            assert!(r.tls.is_none());
        }

        #[test]
        fn classify_http1_post_request_line_matches_http1() {
            let r = classify_short(b"POST /x HTTP/1.0\r\n");
            assert_eq!(r.detected, Some(DetectedProtocol::Http1));
        }

        // "G" could still grow into "GET ", so the result stays undecided.
        #[test]
        fn classify_http1_partial_method_is_indeterminate() {
            let r = classify_short(b"G");
            assert!(r.detected.is_none());
        }

        // A request line that ends without an HTTP/1.x version is not treated as HTTP/1.
        #[test]
        fn classify_http1_http_0_9_request_line_does_not_match_http1() {
            let r = classify_short(b"GET /\r\n");
            assert_eq!(r.detected, Some(DetectedProtocol::Unknown));
        }

        #[test]
        fn classify_http1_unknown_method_is_unknown() {
            let r = classify_short(b"FOO /index HTTP/1.1\r\n");
            assert_eq!(r.detected, Some(DetectedProtocol::Unknown));
        }

        #[test]
        fn classify_http2_preface_exact_match() {
            let r = classify_short(H2_PREFACE);
            assert_eq!(r.detected, Some(DetectedProtocol::Http2Preface));
        }

        // A correct but truncated preface waits for more bytes.
        #[test]
        fn classify_http2_preface_partial_is_indeterminate() {
            let r = classify_short(b"PRI * HTTP/2.0\r\n");
            assert!(r.detected.is_none());
        }

        // Corrupting only the final preface byte must rule HTTP/2 out.
        #[test]
        fn classify_http2_preface_close_but_wrong_byte_is_unknown() {
            let mut bad = H2_PREFACE.to_vec();
            *bad.last_mut().expect("preface non-empty") = b'x';
            let r = classify(&bad);
            assert_eq!(r.detected, Some(DetectedProtocol::Unknown));
        }

        // Round-trip through a real rustls client: SNI and ALPN must be recovered.
        #[test]
        fn classify_tls_client_hello_matches_and_extracts_sni_alpn() {
            install_crypto();
            let bytes = build_client_hello_bytes("api.example.com", &[b"h2".to_vec()]);
            let r = classify(&bytes);
            assert_eq!(r.detected, Some(DetectedProtocol::TlsClientHello));
            let tls = r.tls.expect("tls hello populated");
            assert_eq!(tls.sni.as_deref(), Some("api.example.com"));
            assert!(tls.alpn.iter().any(|p| p == b"h2"), "alpn includes h2: {:?}", tls.alpn);
        }

        // Only the first 6 bytes of a real ClientHello: undecided, not rejected.
        #[test]
        fn classify_tls_truncated_is_indeterminate() {
            install_crypto();
            let bytes = build_client_hello_bytes("api.example.com", &[b"h2".to_vec()]);
            let r = classify(&bytes[..6]);
            assert!(r.detected.is_none());
        }

        // A handshake-type first byte followed by garbage must fall through to Unknown.
        #[test]
        fn classify_tls_byte_then_garbage_falls_back_to_unknown() {
            let mut buf = vec![0x16u8];
            buf.extend(std::iter::repeat_n(0xFFu8, 64));
            let r = classify(&buf);
            assert_eq!(r.detected, Some(DetectedProtocol::Unknown));
        }

        // A full-size buffer of repeating 0x00..=0xFF bytes matches no detector.
        #[test]
        fn classify_random_8kib_is_unknown() {
            let buf: Vec<u8> = (0..MAX_PEEK_BYTES).map(|i| u8::try_from(i & 0xFF).unwrap()).collect();
            let r = classify(&buf);
            assert_eq!(r.detected, Some(DetectedProtocol::Unknown));
        }

        #[test]
        fn h2_preface_constant_matches_spec() {
            assert_eq!(H2_PREFACE, b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n");
            assert_eq!(H2_PREFACE.len(), 24);
        }

        // Installs aws-lc-rs as the process-wide default crypto provider. The result is
        // discarded so repeated calls (one per test) do not fail once a default is set.
        fn install_crypto() {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        }

        // Test-only verifier that accepts any certificate and signature. The tests only
        // generate a ClientHello and never complete a handshake.
        #[derive(Debug)]
        struct NoVerify;
        impl rustls::client::danger::ServerCertVerifier for NoVerify {
            fn verify_server_cert(
                &self,
                _end_entity: &rustls::pki_types::CertificateDer<'_>,
                _intermediates: &[rustls::pki_types::CertificateDer<'_>],
                _server_name: &rustls::pki_types::ServerName<'_>,
                _ocsp_response: &[u8],
                _now: rustls::pki_types::UnixTime,
            ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
                Ok(rustls::client::danger::ServerCertVerified::assertion())
            }
            fn verify_tls12_signature(
                &self,
                _message: &[u8],
                _cert: &rustls::pki_types::CertificateDer<'_>,
                _dss: &rustls::DigitallySignedStruct,
            ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            fn verify_tls13_signature(
                &self,
                _message: &[u8],
                _cert: &rustls::pki_types::CertificateDer<'_>,
                _dss: &rustls::DigitallySignedStruct,
            ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            // Reads the process default provider, so `install_crypto` must run first.
            fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
                rustls::crypto::CryptoProvider::get_default()
                    .expect("crypto provider")
                    .signature_verification_algorithms
                    .supported_schemes()
            }
        }

        // Builds a rustls client for `server_name` offering `alpn`, and returns the bytes
        // from its first `write_tls` call (the record(s) carrying its ClientHello).
        fn build_client_hello_bytes(server_name: &str, alpn: &[Vec<u8>]) -> Vec<u8> {
            use std::sync::Arc;

            let mut config = rustls::ClientConfig::builder()
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(NoVerify))
                .with_no_client_auth();
            config.alpn_protocols = alpn.to_vec();
            let server =
                rustls::pki_types::ServerName::try_from(server_name.to_owned()).expect("server name");
            let mut conn = rustls::ClientConnection::new(Arc::new(config), server).expect("client conn");
            let mut out = Vec::new();
            conn.write_tls(&mut out).expect("write_tls");
            out
        }
    }
}