Skip to main content

guess/
lib.rs

1//! Wire-protocol classifier for TCP / TLS streams.
2//!
3//! Given the first bytes of a freshly accepted connection, [`classify`]
4//! runs a detector cascade and reports one of:
5//!
6//! - [`DetectedProtocol::TlsClientHello`] — the prefix is a complete
7//!   TLS `ClientHello` (parsed via rustls's `Acceptor`); SNI and ALPN
8//!   are extracted into [`PeekResult::tls`].
9//! - [`DetectedProtocol::Http2Preface`] — the bytes match the 24-byte
10//!   HTTP/2 connection preface from RFC 7540 §3.5.
11//! - [`DetectedProtocol::Http1`] — the bytes start with a known
12//!   HTTP/1 method and the request line carries an `HTTP/1.0` or
13//!   `HTTP/1.1` version marker.
14//! - [`DetectedProtocol::Unknown`] — every detector ruled itself out.
15//!
16//! The cascade is three-state: a detector can also say "I'd be willing
17//! to commit if I saw a few more bytes." When *any* detector returns
18//! that, [`classify`] surfaces `detected = None` so the caller can
19//! read more bytes (up to [`MAX_PEEK_BYTES`]) and call again. When
20//! every detector has ruled itself out, the result is `Unknown` and
21//! further reads cannot change the outcome.
22//!
23//! ## Types-only consumers
24//!
25//! The default `classify` feature pulls in `rustls` (for the TLS
26//! parse) and `memchr` (for the HTTP/1 scan). Disable defaults to
27//! get only the result types — useful when a downstream crate wants
28//! to *describe* a peek without performing one:
29//!
30//! ```toml
31//! guess = { version = "0.2", default-features = false }
32//! ```
33
34use bytes::Bytes;
35
36/// Outcome of one peek-buffer classification. `buffer` is the bytes
37/// that were classified (kept on the result so consumers can replay
38/// them to a downstream decoder via, e.g., `peeked-stream`).
39/// `detected` is `None` when at least one detector wants more bytes;
40/// the caller should read more and call [`classify`] again.
41#[derive(Clone, Debug)]
42pub struct PeekResult {
43	pub buffer: Bytes,
44	pub detected: Option<DetectedProtocol>,
45	pub tls: Option<TlsClientHello>,
46}
47
/// Protocol verdict carried in [`PeekResult::detected`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectedProtocol {
	/// The prefix is a complete TLS `ClientHello`.
	TlsClientHello,
	/// The prefix starts with a known HTTP/1 method and its request
	/// line carries an `HTTP/1.0` or `HTTP/1.1` version marker.
	Http1,
	/// The prefix matches the 24-byte HTTP/2 connection preface.
	Http2Preface,
	/// NOTE(review): presumably a QUIC Initial packet; no detector in
	/// this file produces this variant — confirm who sets it.
	QuicInitial,
	/// NOTE(review): presumably a DNS message; no detector in this
	/// file produces this variant — confirm who sets it.
	Dns,
	/// Every detector ruled itself out; more bytes cannot change that.
	Unknown,
}
57
/// Fields extracted from a parsed TLS `ClientHello`.
///
/// The `Default` value (no SNI, no ALPN entries) doubles as the
/// fallback when a re-parse of the hello fails.
#[derive(Debug, Default, Clone)]
pub struct TlsClientHello {
	/// Server name from the SNI extension, if the client sent one.
	/// The parser in this crate stores it ASCII-lowercased.
	pub sni: Option<String>,
	/// ALPN protocol IDs offered by the client in the `ClientHello`.
	pub alpn: Vec<Vec<u8>>,
}
64
/// Upper bound on how many bytes a peek prelude should accumulate
/// before giving up and treating the connection's prefix as `Unknown`.
///
/// 8 KiB matches what most servers can read in a single non-blocking
/// syscall and covers any realistic TLS `ClientHello`
/// (with SNI + ALPN + GREASE).
pub const MAX_PEEK_BYTES: usize = 8_192;
70
/// The 24-byte HTTP/2 client connection preface (RFC 7540 §3.5) that
/// the HTTP/2 detector compares the peek buffer against.
#[cfg(feature = "classify")]
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
73
/// Request methods the HTTP/1 detector accepts.
///
/// Each entry includes the trailing space, so the comparison is a
/// case-sensitive `<METHOD> ` prefix match against the peek buffer
/// (a longer token that merely begins with a method name is rejected).
#[cfg(feature = "classify")]
const HTTP1_METHODS: &[&[u8]] = &[
	b"GET ",
	b"POST ",
	b"PUT ",
	b"DELETE ",
	b"HEAD ",
	b"OPTIONS ",
	b"PATCH ",
	b"CONNECT ",
	b"TRACE ",
];
89
/// Version anchor searched for in an HTTP/1 request line. The match
/// is only definitive once the byte right after the anchor is known
/// to be `'0'` or `'1'` (i.e. `HTTP/1.0` or `HTTP/1.1`).
#[cfg(feature = "classify")]
const HTTP1_VERSION_PREFIX: &[u8] = b" HTTP/1.";
94
/// Classify the current peek buffer by running the detector cascade
/// (TLS, then HTTP/2 preface, then HTTP/1).
///
/// The verdict is reported in [`PeekResult::detected`]:
///
/// - `Some(protocol)` — a detector committed to `protocol`. For a TLS
///   match, [`PeekResult::tls`] carries the parsed SNI / ALPN.
/// - `None` — the buffer is empty or a detector wants more bytes; the
///   caller should keep reading (until [`MAX_PEEK_BYTES`] or a read
///   timeout) and call again.
/// - `Some(DetectedProtocol::Unknown)` — every detector ruled itself
///   out, so further reads cannot change the outcome.
///
/// `buf` is copied into [`PeekResult::buffer`] on every call.
#[cfg(feature = "classify")]
#[must_use]
pub fn classify(buf: &[u8]) -> PeekResult {
	// Single place where the result (including the buffer copy) is built.
	let finish = |detected: Option<DetectedProtocol>, tls: Option<TlsClientHello>| PeekResult {
		buffer: Bytes::copy_from_slice(buf),
		detected,
		tls,
	};

	if buf.is_empty() {
		return finish(None, None);
	}

	// TLS goes first and is the only detector that yields extra data.
	match detect_tls(buf) {
		DetectorOutcome::Match => {
			let hello = parse_client_hello(buf);
			return finish(Some(DetectedProtocol::TlsClientHello), Some(hello));
		}
		DetectorOutcome::NeedMore => return finish(None, None),
		DetectorOutcome::NoMatch => {}
	}

	// Plaintext detectors, tried in order; the first one that does not
	// rule itself out decides the result.
	let plaintext: [(fn(&[u8]) -> DetectorOutcome, DetectedProtocol); 2] = [
		(detect_h2_preface, DetectedProtocol::Http2Preface),
		(detect_http1, DetectedProtocol::Http1),
	];
	for (detector, protocol) in plaintext {
		match detector(buf) {
			DetectorOutcome::Match => return finish(Some(protocol), None),
			DetectorOutcome::NeedMore => return finish(None, None),
			DetectorOutcome::NoMatch => {}
		}
	}

	// Nobody claimed the prefix — it is opaque to us.
	finish(Some(DetectedProtocol::Unknown), None)
}
151
/// Three-state answer returned by each detector in the cascade.
#[cfg(feature = "classify")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DetectorOutcome {
	/// The detector commits: the buffer is its protocol.
	Match,
	/// The bytes so far are consistent with the protocol, but more are
	/// required before committing.
	NeedMore,
	/// The detector rules itself out for this buffer.
	NoMatch,
}
159
/// TLS detector: does `buf` begin with a TLS `ClientHello`?
///
/// A cheap gate on the first byte (`0x16`, the TLS handshake record
/// type) avoids invoking rustls for obviously non-TLS prefixes; the
/// actual parse is delegated to rustls's `Acceptor`.
///
/// Returns
/// - `Match` when rustls accepts a complete `ClientHello`,
/// - `NeedMore` when all of `buf` was fed without error but the hello
///   is still incomplete,
/// - `NoMatch` when the first byte is not `0x16` or rustls rejects
///   the bytes.
#[cfg(feature = "classify")]
fn detect_tls(buf: &[u8]) -> DetectorOutcome {
	if buf.first() != Some(&0x16) {
		return DetectorOutcome::NoMatch;
	}
	let mut acceptor = rustls::server::Acceptor::default();
	let mut input: &[u8] = buf;
	// One `read_tls` call performs a single read and may take in only
	// part of `buf`, so a large `ClientHello` needs several rounds.
	// Keep feeding until rustls decides or the slice is drained;
	// otherwise a hello that is fully present could be reported as
	// `NeedMore` forever.
	loop {
		match acceptor.read_tls(&mut input) {
			// Nothing was consumed: no further progress is possible.
			Ok(0) => return DetectorOutcome::NeedMore,
			Ok(_) => {}
			Err(_) => return DetectorOutcome::NoMatch,
		}
		match acceptor.accept() {
			Ok(Some(_)) => return DetectorOutcome::Match,
			Ok(None) if input.is_empty() => return DetectorOutcome::NeedMore,
			// Incomplete so far, but unread bytes remain: go again.
			Ok(None) => {}
			Err(_) => return DetectorOutcome::NoMatch,
		}
	}
}
176
/// HTTP/2 detector: compares `buf` against [`H2_PREFACE`].
///
/// - `Match` when `buf` starts with the whole preface.
/// - `NeedMore` when `buf` is shorter than the preface and agrees with
///   it so far (this includes an empty `buf`).
/// - `NoMatch` as soon as any overlapping byte differs.
#[cfg(feature = "classify")]
fn detect_h2_preface(buf: &[u8]) -> DetectorOutcome {
	// Only the bytes both sides actually have can be compared.
	let overlap = buf.len().min(H2_PREFACE.len());
	if buf[..overlap] != H2_PREFACE[..overlap] {
		return DetectorOutcome::NoMatch;
	}
	if overlap == H2_PREFACE.len() {
		DetectorOutcome::Match
	} else {
		DetectorOutcome::NeedMore
	}
}
188
/// HTTP/1 detector: a known `<METHOD> ` prefix whose request line
/// carries ` HTTP/1.0` or ` HTTP/1.1`.
///
/// - `NoMatch` — `buf` neither starts with nor is a prefix of any
///   entry in [`HTTP1_METHODS`]; or a `\r\n` appears before (or
///   without) the version anchor; or the byte after the anchor is not
///   `'0'` / `'1'`.
/// - `NeedMore` — `buf` is a strict prefix of some method, or the
///   method matched but neither the anchor nor `\r\n` has shown up
///   yet, or the anchor is the last thing in `buf`.
/// - `Match` — method matched and the anchor is followed by `'0'` or
///   `'1'` before any `\r\n`.
#[cfg(feature = "classify")]
fn detect_http1(buf: &[u8]) -> DetectorOutcome {
	let starts_with_method = HTTP1_METHODS.iter().any(|method| buf.starts_with(method));
	if !starts_with_method {
		// No complete `<METHOD> ` yet — it may still be the start of one.
		let could_become_method =
			HTTP1_METHODS.iter().any(|method| buf.len() < method.len() && method.starts_with(buf));
		return if could_become_method {
			DetectorOutcome::NeedMore
		} else {
			DetectorOutcome::NoMatch
		};
	}

	// Method+SP is present. A `\r\n` ahead of the version anchor means
	// the request line ended without a known HTTP/1 marker (HTTP/0.9
	// or junk).
	let line_end = memchr::memmem::find(buf, b"\r\n");
	let Some(anchor) = memchr::memmem::find(buf, HTTP1_VERSION_PREFIX) else {
		// No anchor: hopeless once the line has ended, otherwise wait.
		return if line_end.is_some() {
			DetectorOutcome::NoMatch
		} else {
			DetectorOutcome::NeedMore
		};
	};
	if line_end.is_some_and(|end| end < anchor) {
		return DetectorOutcome::NoMatch;
	}

	// The minor-version digit sits right after the anchor.
	let minor_version = buf.get(anchor + HTTP1_VERSION_PREFIX.len()).copied();
	match minor_version {
		Some(b'0' | b'1') => DetectorOutcome::Match,
		Some(_) => DetectorOutcome::NoMatch,
		None => DetectorOutcome::NeedMore,
	}
}
225
/// Parse a complete `ClientHello` out of `buf` and extract its SNI
/// and ALPN offers.
///
/// The caller is expected to have already seen [`detect_tls`] return
/// `Match` for the same bytes. If the re-parse nevertheless fails
/// (read error, rejected hello, or input drained before the hello
/// completes) an empty [`TlsClientHello`] is returned rather than
/// panicking.
///
/// The SNI host name is stored ASCII-lowercased; ALPN protocol IDs
/// are copied out as owned byte vectors.
#[cfg(feature = "classify")]
fn parse_client_hello(buf: &[u8]) -> TlsClientHello {
	let mut acceptor = rustls::server::Acceptor::default();
	let mut input: &[u8] = buf;
	// One `read_tls` call performs a single read and may take in only
	// part of `buf`; keep feeding until the hello is accepted or there
	// is nothing left to feed.
	let accepted = loop {
		match acceptor.read_tls(&mut input) {
			Ok(consumed) if consumed > 0 => {}
			// Read error, or no progress possible.
			_ => return TlsClientHello::default(),
		}
		match acceptor.accept() {
			Ok(Some(accepted)) => break accepted,
			// Incomplete so far, but unread bytes remain: go again.
			Ok(None) if !input.is_empty() => {}
			// Rejected, or drained without completing the hello.
			_ => return TlsClientHello::default(),
		}
	};
	let hello = accepted.client_hello();
	let sni = hello.server_name().map(str::to_ascii_lowercase);
	let alpn: Vec<Vec<u8>> =
		hello.alpn().map_or_else(Vec::new, |offers| offers.map(<[u8]>::to_vec).collect());
	TlsClientHello { sni, alpn }
}
246
#[cfg(test)]
mod tests {
	use super::*;

	/// Compile-time check that `PeekResult` is cloneable and can be
	/// shared across threads.
	#[test]
	fn peek_result_is_clone_send_sync_static() {
		fn assert_bounds<T>()
		where
			T: Clone + Send + Sync + 'static,
		{
		}
		assert_bounds::<PeekResult>();
	}

	/// Every variant equals itself and nothing else.
	#[test]
	fn detected_protocol_variants_are_distinct() {
		let variants = [
			DetectedProtocol::TlsClientHello,
			DetectedProtocol::Http1,
			DetectedProtocol::Http2Preface,
			DetectedProtocol::QuicInitial,
			DetectedProtocol::Dns,
			DetectedProtocol::Unknown,
		];
		for (outer_idx, outer) in variants.iter().enumerate() {
			for (inner_idx, inner) in variants.iter().enumerate() {
				assert_eq!(outer == inner, outer_idx == inner_idx);
			}
		}
	}

	#[test]
	fn tls_client_hello_default_is_empty() {
		let TlsClientHello { sni, alpn } = TlsClientHello::default();
		assert!(sni.is_none());
		assert!(alpn.is_empty());
	}

	#[test]
	fn max_peek_bytes_is_8k() {
		assert_eq!(MAX_PEEK_BYTES, 8 * 1024);
	}

	/// End-to-end tests of the detector cascade.
	#[cfg(feature = "classify")]
	mod classify {
		use super::*;

		#[test]
		fn classify_empty_buffer_is_indeterminate() {
			let result = classify(&[]);
			assert!(result.detected.is_none());
			assert!(result.tls.is_none());
			assert!(result.buffer.is_empty());
		}

		#[test]
		fn classify_http1_get_request_line_matches_http1() {
			let result = classify(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n");
			assert_eq!(result.detected, Some(DetectedProtocol::Http1));
			assert!(result.tls.is_none());
		}

		#[test]
		fn classify_http1_post_request_line_matches_http1() {
			let result = classify(b"POST /x HTTP/1.0\r\n");
			assert_eq!(result.detected, Some(DetectedProtocol::Http1));
		}

		#[test]
		fn classify_http1_partial_method_is_indeterminate() {
			// A lone `G` could still grow into `GET `, so the caller
			// has to read more bytes.
			let result = classify(b"G");
			assert!(result.detected.is_none());
		}

		#[test]
		fn classify_http1_http_0_9_request_line_does_not_match_http1() {
			// `GET /\r\n` is a valid HTTP/0.9 request: the line ends
			// with no version marker, so the detector must reject it.
			let result = classify(b"GET /\r\n");
			assert_eq!(result.detected, Some(DetectedProtocol::Unknown));
		}

		#[test]
		fn classify_http1_unknown_method_is_unknown() {
			let result = classify(b"FOO /index HTTP/1.1\r\n");
			assert_eq!(result.detected, Some(DetectedProtocol::Unknown));
		}

		#[test]
		fn classify_http2_preface_exact_match() {
			let result = classify(H2_PREFACE);
			assert_eq!(result.detected, Some(DetectedProtocol::Http2Preface));
		}

		#[test]
		fn classify_http2_preface_partial_is_indeterminate() {
			let result = classify(b"PRI * HTTP/2.0\r\n");
			assert!(result.detected.is_none());
		}

		#[test]
		fn classify_http2_preface_close_but_wrong_byte_is_unknown() {
			let mut corrupted = H2_PREFACE.to_vec();
			let final_byte = corrupted.last_mut().expect("preface non-empty");
			*final_byte = b'x';
			let result = classify(&corrupted);
			assert_eq!(result.detected, Some(DetectedProtocol::Unknown));
		}

		#[test]
		fn classify_tls_client_hello_matches_and_extracts_sni_alpn() {
			install_crypto();
			let hello_bytes = build_client_hello_bytes("api.example.com", &[b"h2".to_vec()]);
			let PeekResult { detected, tls, .. } = classify(&hello_bytes);
			assert_eq!(detected, Some(DetectedProtocol::TlsClientHello));
			let tls = tls.expect("tls hello populated");
			assert_eq!(tls.sni.as_deref(), Some("api.example.com"));
			assert!(tls.alpn.iter().any(|p| p == b"h2"), "alpn includes h2: {:?}", tls.alpn);
		}

		#[test]
		fn classify_tls_truncated_is_indeterminate() {
			install_crypto();
			let hello_bytes = build_client_hello_bytes("api.example.com", &[b"h2".to_vec()]);
			let truncated = &hello_bytes[..6];
			let result = classify(truncated);
			assert!(result.detected.is_none());
		}

		#[test]
		fn classify_tls_byte_then_garbage_falls_back_to_unknown() {
			// TLS handshake record byte followed by 64 bytes of 0xFF.
			let mut buf = vec![0xFFu8; 65];
			buf[0] = 0x16;
			let result = classify(&buf);
			assert_eq!(result.detected, Some(DetectedProtocol::Unknown));
		}

		#[test]
		fn classify_random_8kib_is_unknown() {
			// Repeating 0x00..=0xFF ramp, `MAX_PEEK_BYTES` long.
			let buf: Vec<u8> = (0..=u8::MAX).cycle().take(MAX_PEEK_BYTES).collect();
			let result = classify(&buf);
			assert_eq!(result.detected, Some(DetectedProtocol::Unknown));
		}

		#[test]
		fn h2_preface_constant_matches_spec() {
			assert_eq!(H2_PREFACE, b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n");
			assert_eq!(H2_PREFACE.len(), 24);
		}

		/// Install the aws-lc-rs provider as the process default. The
		/// result is ignored, so repeated calls are harmless.
		fn install_crypto() {
			let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
		}

		/// Certificate verifier that accepts everything. It exists only
		/// so a client connection can be constructed to emit a hello.
		#[derive(Debug)]
		struct NoVerify;

		impl rustls::client::danger::ServerCertVerifier for NoVerify {
			fn verify_server_cert(
				&self,
				_end_entity: &rustls::pki_types::CertificateDer<'_>,
				_intermediates: &[rustls::pki_types::CertificateDer<'_>],
				_server_name: &rustls::pki_types::ServerName<'_>,
				_ocsp_response: &[u8],
				_now: rustls::pki_types::UnixTime,
			) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
				Ok(rustls::client::danger::ServerCertVerified::assertion())
			}

			fn verify_tls12_signature(
				&self,
				_message: &[u8],
				_cert: &rustls::pki_types::CertificateDer<'_>,
				_dss: &rustls::DigitallySignedStruct,
			) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
				Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
			}

			fn verify_tls13_signature(
				&self,
				_message: &[u8],
				_cert: &rustls::pki_types::CertificateDer<'_>,
				_dss: &rustls::DigitallySignedStruct,
			) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
				Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
			}

			fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
				let provider = rustls::crypto::CryptoProvider::get_default().expect("crypto provider");
				provider.signature_verification_algorithms.supported_schemes()
			}
		}

		/// Produce the bytes of a TLS `ClientHello` by driving rustls's
		/// client state machine and capturing what it would write to a
		/// socket.
		fn build_client_hello_bytes(server_name: &str, alpn: &[Vec<u8>]) -> Vec<u8> {
			use std::sync::Arc;

			let verifier = Arc::new(NoVerify);
			let mut config = rustls::ClientConfig::builder()
				.dangerous()
				.with_custom_certificate_verifier(verifier)
				.with_no_client_auth();
			config.alpn_protocols = alpn.to_vec();

			let name =
				rustls::pki_types::ServerName::try_from(server_name.to_owned()).expect("server name");
			let mut client = rustls::ClientConnection::new(Arc::new(config), name).expect("client conn");

			let mut wire = Vec::new();
			client.write_tls(&mut wire).expect("write_tls");
			wire
		}
	}
}