cellos-supervisor 0.5.1

CellOS execution-cell runner — boots cells in Firecracker microVMs or gVisor, enforces narrow typed authority, emits signed CloudEvents.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
//! Pure-byte TLS ClientHello SNI extractor — no TLS decoding crates pulled in.
//!
//! The SEC-22 Phase 2 SNI proxy peeks the first bytes of a TCP connection and,
//! when they look like TLS, extracts the `server_name` extension's first
//! `host_name` entry (RFC 6066). The proxy NEVER terminates TLS; this module
//! walks the ClientHello as plain bytes.
//!
//! # Wire format walked here
//!
//! - **TLS record header (5 bytes)**: `type(1) version(2) length(2)`. We
//!   require `type == 22` (handshake) and `length <= 16384`.
//! - **Handshake header (4 bytes)**: `type(1) length(3)`. We require
//!   `type == 1` (ClientHello).
//! - **ClientHello body**:
//!   - `legacy_version(2)`
//!   - `random(32)`
//!   - `legacy_session_id` — 1-byte length + that many bytes (max 32)
//!   - `cipher_suites` — 2-byte length + that many bytes (must be even)
//!   - `legacy_compression_methods` — 1-byte length + that many bytes
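//!   - `extensions` — 2-byte length + that many bytes. The whole block is
//!     optional: a ClientHello may end right after the compression methods,
//!     in which case it carries no SNI.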
//! - Inside extensions, `server_name` (type 0) carries:
//!   - 2-byte `server_name_list` length
//!   - one or more `(name_type:1, name:length-prefixed-2)` entries; we return
//!     the first `host_name` (`name_type == 0`).
//!
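//! # Errors
//!
//! Any truncation, oversized length field, or wrong type byte is reported as
//! [`SniParseError`]. `Ok(None)` is reserved for well-formed ClientHellos that
//! carry no usable SNI. That covers three cases:
//! - no extensions block at all;
//! - no `server_name` extension (RFC 6066 explicitly permits omission);
//! - a `server_name` extension whose first `host_name` entry is missing or
//!   empty.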

use std::fmt;

/// Maximum TLS record body length we accept.
///
/// RFC 5246 §6.2.1 (and RFC 8446 §5.1) cap `TLSPlaintext.length` at 2^14
/// (16384) bytes. A ClientHello is always sent in a plaintext record, so any
/// larger declared length is malformed and we reject it defensively. The
/// higher ceilings those RFCs allow for compressed/encrypted records do not
/// apply to the ClientHello.
pub const MAX_RECORD_LEN: usize = 16384;

/// Failure modes reported by [`extract_sni`].
///
/// Each structural problem gets its own variant. If a future Phase splits
/// malformed-input handling out of the single `l7_unknown_protocol`
/// reasonCode, the proxy can then map these to richer reasonCodes without
/// touching the parser.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SniParseError {
    /// Fewer than 5 bytes were supplied, so not even a TLS record header fits.
    TooShort,
    /// The record content type is not 22 (handshake): the stream is not TLS.
    NotHandshake,
    /// The record header's `length` is above [`MAX_RECORD_LEN`].
    RecordTooLarge,
    /// The handshake message type is not 1 (ClientHello).
    NotClientHello,
    /// A length field pointed past the end of the supplied bytes.
    Truncated,
    /// A length field was larger than its enclosing structure permits.
    InvalidLength,
}

impl fmt::Display for SniParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except RecordTooLarge has a fixed message; that one
        // interpolates the limit, so it formats and returns directly.
        let message = match self {
            Self::TooShort => "buffer too short for TLS record header",
            Self::NotHandshake => "first byte is not handshake (type=22)",
            Self::RecordTooLarge => {
                return write!(f, "TLS record length exceeds {MAX_RECORD_LEN} bytes");
            }
            Self::NotClientHello => "handshake type is not ClientHello (type=1)",
            Self::Truncated => "ClientHello truncated mid-field",
            Self::InvalidLength => "ClientHello length field overflows container",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SniParseError {}

/// Consume exactly `n` bytes from `buf` at cursor `*idx` and advance the
/// cursor past them.
///
/// Fails with [`SniParseError::InvalidLength`] if `*idx + n` overflows
/// `usize`. Fails with [`SniParseError::Truncated`] if the requested span does
/// not lie within `buf`. On failure the cursor is left untouched. Routing every
/// read through this helper keeps the parser free of raw slice indexing.
fn take<'a>(buf: &'a [u8], idx: &mut usize, n: usize) -> Result<&'a [u8], SniParseError> {
    let start = *idx;
    let stop = start.checked_add(n).ok_or(SniParseError::InvalidLength)?;
    let chunk = buf.get(start..stop).ok_or(SniParseError::Truncated)?;
    *idx = stop;
    Ok(chunk)
}

/// Read a single byte at the cursor.
fn read_u8(buf: &[u8], idx: &mut usize) -> Result<u8, SniParseError> {
    take(buf, idx, 1).map(|byte| byte[0])
}

/// Read a big-endian (network order) `u16` at the cursor.
fn read_u16(buf: &[u8], idx: &mut usize) -> Result<u16, SniParseError> {
    let pair = take(buf, idx, 2)?;
    Ok((u16::from(pair[0]) << 8) | u16::from(pair[1]))
}

/// Read a big-endian 24-bit integer (TLS `uint24`) at the cursor, widened to
/// `u32`.
fn read_u24(buf: &[u8], idx: &mut usize) -> Result<u32, SniParseError> {
    let triple = take(buf, idx, 3)?;
    Ok(triple.iter().fold(0u32, |acc, &b| (acc << 8) | u32::from(b)))
}

/// Extract the first `host_name` (RFC 6066 type 0) entry from a buffer that
/// is expected to start with a TLS record carrying a ClientHello.
///
/// Only the first TLS record is inspected. Its body must be fully present in
/// `client_hello` and must contain the whole ClientHello handshake message.
///
/// - Returns `Ok(Some(name))` on a well-formed ClientHello whose
///   `server_name` extension contains at least one `host_name`. The returned
///   name is lowercased and any single trailing dot is stripped.
/// - Returns `Ok(None)` on a well-formed ClientHello that has no extensions
///   block, has no `server_name` extension, or whose `server_name` extension
///   carried no usable `host_name`.
/// - Returns `Err(SniParseError)` on any structural anomaly (truncation,
///   oversized lengths, wrong record type).
///
/// The function does NOT validate the hostname against any allowlist — that
/// is [`crate::sni_proxy`]'s responsibility downstream.
///
/// # Errors
///
/// - [`SniParseError::TooShort`]: fewer than 5 bytes supplied.
/// - [`SniParseError::NotHandshake`]: record content type is not 22.
/// - [`SniParseError::RecordTooLarge`]: record length above [`MAX_RECORD_LEN`].
/// - [`SniParseError::NotClientHello`]: handshake message type is not 1.
/// - [`SniParseError::Truncated`]: one of the following:
///   - the record is not fully buffered;
///   - the handshake is too short for `legacy_version` + `random`;
///   - the extensions block ends with a partial extension header.
/// - [`SniParseError::InvalidLength`]: one of the following:
///   - a length prefix overruns its container (including a ClientHello
///     fragmented across several records);
///   - the cipher-suite length is odd;
///   - the session id is longer than 32 bytes.
pub fn extract_sni(client_hello: &[u8]) -> Result<Option<String>, SniParseError> {
    let mut idx = 0usize;

    // ---- TLS record header: type(1) legacy_record_version(2) length(2) ----
    if client_hello.len() < 5 {
        return Err(SniParseError::TooShort);
    }
    if read_u8(client_hello, &mut idx)? != 22 {
        return Err(SniParseError::NotHandshake);
    }
    // legacy_record_version is deliberately not checked: RFC 8446 §5.1 allows
    // an initial ClientHello to use 0x0301 as well as 0x0303.
    let _ = read_u16(client_hello, &mut idx)?;
    let record_len = usize::from(read_u16(client_hello, &mut idx)?);
    if record_len > MAX_RECORD_LEN {
        return Err(SniParseError::RecordTooLarge);
    }
    // Only the first record is inspected; every mainstream client emits the
    // ClientHello in a single record. A record that is not fully buffered is
    // reported as `Truncated`. A ClientHello fragmented across several records
    // is reported as `InvalidLength` by the handshake-length check below.
    let record_end = idx
        .checked_add(record_len)
        .ok_or(SniParseError::InvalidLength)?;
    if record_end > client_hello.len() {
        return Err(SniParseError::Truncated);
    }

    // ---- Handshake header: msg_type(1) length(3) ----
    if read_u8(client_hello, &mut idx)? != 1 {
        return Err(SniParseError::NotClientHello);
    }
    let hs_len = read_u24(client_hello, &mut idx)? as usize;
    let hs_end = idx
        .checked_add(hs_len)
        .ok_or(SniParseError::InvalidLength)?;
    if hs_end > record_end {
        return Err(SniParseError::InvalidLength);
    }

    // ---- ClientHello body ----
    // legacy_version(2) + random(32). legacy_version is not gated, so any
    // TLS version's ClientHello is walked.
    if hs_len < 2 + 32 {
        return Err(SniParseError::Truncated);
    }
    let _ = take(client_hello, &mut idx, 2 + 32)?;

    // legacy_session_id: 1-byte length, at most 32 bytes.
    let sid_len = usize::from(read_u8(client_hello, &mut idx)?);
    if sid_len > 32 {
        return Err(SniParseError::InvalidLength);
    }
    let _ = take(client_hello, &mut idx, sid_len)?;

    // cipher_suites: 2-byte length. It must be even (each suite is 2 bytes)
    // and must stay inside the handshake message.
    let cs_len = usize::from(read_u16(client_hello, &mut idx)?);
    if !cs_len.is_multiple_of(2) || idx + cs_len > hs_end {
        return Err(SniParseError::InvalidLength);
    }
    let _ = take(client_hello, &mut idx, cs_len)?;

    // legacy_compression_methods: 1-byte length, inside the handshake message.
    let comp_len = usize::from(read_u8(client_hello, &mut idx)?);
    if idx + comp_len > hs_end {
        return Err(SniParseError::InvalidLength);
    }
    let _ = take(client_hello, &mut idx, comp_len)?;

    // The extensions block is optional: a ClientHello may end right after the
    // compression methods, and then there is no SNI.
    if idx == hs_end {
        return Ok(None);
    }
    let ext_total = usize::from(read_u16(client_hello, &mut idx)?);
    if idx + ext_total > hs_end {
        return Err(SniParseError::InvalidLength);
    }
    let ext_end = idx + ext_total;

    // Walk extensions (type(2) length(2) body) looking for server_name (0).
    while idx < ext_end {
        // 1-3 leftover bytes cannot hold an extension header. The block is
        // malformed, so fail instead of reporting "no SNI".
        if ext_end - idx < 4 {
            return Err(SniParseError::Truncated);
        }
        let ext_type = read_u16(client_hello, &mut idx)?;
        let ext_len = usize::from(read_u16(client_hello, &mut idx)?);
        if idx + ext_len > ext_end {
            return Err(SniParseError::InvalidLength);
        }
        let ext_body = take(client_hello, &mut idx, ext_len)?;
        if ext_type == 0 {
            return parse_server_name_extension(ext_body);
        }
    }
    Ok(None)
}

/// Parse the body of the `server_name` extension and return the first
/// `host_name` entry. RFC 6066 §3:
///
/// ```text
///   ServerNameList   server_name_list   (2-byte length-prefixed)
///   ServerName       server_name        (1-byte name_type + opaque)
///   HostName         host_name          (2-byte length-prefixed bytes)
/// ```
fn parse_server_name_extension(body: &[u8]) -> Result<Option<String>, SniParseError> {
    let mut idx = 0usize;
    let list_len = read_u16(body, &mut idx)? as usize;
    if idx + list_len > body.len() {
        return Err(SniParseError::InvalidLength);
    }
    let list_end = idx + list_len;

    while idx + 3 <= list_end {
        let name_type = read_u8(body, &mut idx)?;
        let name_len = read_u16(body, &mut idx)? as usize;
        if idx + name_len > list_end {
            return Err(SniParseError::InvalidLength);
        }
        if name_type == 0 {
            // host_name. RFC 6066 says ASCII; lowercase to canonical form.
            let raw = &body[idx..idx + name_len];
            // Reject empty names — downstream allowlist match would always
            // fail, but the schema requires minLength=1 for `sniHost`.
            if raw.is_empty() {
                return Ok(None);
            }
            // Best-effort UTF-8 → lowercase. RFC 6066 mandates ASCII for
            // `host_name`; non-ASCII inputs are coerced via from_utf8_lossy
            // and the downstream allowlist match will simply not find them.
            let mut s = String::from_utf8_lossy(raw).to_string();
            s.make_ascii_lowercase();
            // Strip a single trailing dot (FQDN-form clients sometimes send
            // `host.` as the SNI value).
            if s.ends_with('.') {
                s.pop();
            }
            return Ok(Some(s));
        }
        idx += name_len;
    }
    Ok(None)
}

// Unit tests for the ClientHello walker. Fixtures are synthesised by
// `build_client_hello`; several tests patch bytes at fixed offsets, so they
// depend on that builder's exact layout (documented on the builder).
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal but well-formed ClientHello carrying SNI extensions.
    /// Returns the full TLS record bytes. Inspired by `dns_proxy/parser.rs`'s
    /// synth-fixture pattern.
    ///
    /// Every name in `snis` becomes one `host_name` entry inside a single
    /// `server_name` extension. With an empty `snis`, the extensions block is
    /// still emitted, but with length 0.
    ///
    /// Fixed byte layout (tests below patch some of these offsets directly):
    /// - `[0..5]`   record header: type 22, version 0x0301, 16-bit length
    /// - `[5]`      handshake type (1 = ClientHello); `[6..9]` 24-bit length
    /// - `[9..11]`  ClientHello `legacy_version` (0x0303)
    /// - `[11..43]` `random` (all zero)
    /// - `[43]`     `legacy_session_id` length (0)
    /// - `[44..46]` `cipher_suites` length (2); `[46..48]` the single suite
    /// - `[48..50]` compression methods (length 1, method 0)
    /// - `[50..52]` extensions length, followed by the extensions block
    fn build_client_hello(snis: &[&str]) -> Vec<u8> {
        // ---- ClientHello body ----
        let mut body = Vec::new();
        body.extend_from_slice(&[0x03, 0x03]); // legacy_version TLS 1.2
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // session id length 0
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // cipher_suites: 1 suite (TLS_AES_128_GCM_SHA256)
        body.extend_from_slice(&[0x01, 0x00]); // compression methods len=1 type=0

        // Build server_name extension body if snis is non-empty.
        let mut ext_section = Vec::new();
        if !snis.is_empty() {
            // server_name extension.
            let mut sn_body = Vec::new();
            // Build inner list: each entry is name_type(1) + name_len(2) + name.
            let mut inner = Vec::new();
            for s in snis {
                inner.push(0u8); // host_name type
                inner.extend_from_slice(&(s.len() as u16).to_be_bytes());
                inner.extend_from_slice(s.as_bytes());
            }
            sn_body.extend_from_slice(&(inner.len() as u16).to_be_bytes()); // server_name_list length
            sn_body.extend_from_slice(&inner);

            // Wrap as extension: type(2) + length(2) + body.
            ext_section.extend_from_slice(&[0x00, 0x00]); // type 0 = server_name
            ext_section.extend_from_slice(&(sn_body.len() as u16).to_be_bytes());
            ext_section.extend_from_slice(&sn_body);
        }
        // The extensions length is always written, even when it is 0.
        body.extend_from_slice(&(ext_section.len() as u16).to_be_bytes());
        body.extend_from_slice(&ext_section);

        // ---- Handshake header ----
        let mut hs = Vec::new();
        hs.push(1); // ClientHello
        let body_len_bytes = (body.len() as u32).to_be_bytes();
        hs.extend_from_slice(&body_len_bytes[1..]); // 24-bit length
        hs.extend_from_slice(&body);

        // ---- Record header ----
        let mut rec = Vec::new();
        rec.push(22); // handshake
        rec.extend_from_slice(&[0x03, 0x01]); // legacy_record_version TLS 1.0 (per RFC 8446 §5.1)
        rec.extend_from_slice(&(hs.len() as u16).to_be_bytes());
        rec.extend_from_slice(&hs);
        rec
    }

    /// Happy path: a single host_name is returned verbatim (already lowercase).
    #[test]
    fn extracts_well_formed_sni() {
        let bytes = build_client_hello(&["api.example.com"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    /// An empty (zero-length) extensions block means no SNI, not an error.
    #[test]
    fn no_sni_returns_ok_none() {
        let bytes = build_client_hello(&[]);
        assert_eq!(extract_sni(&bytes), Ok(None));
    }

    /// Fewer than 5 bytes cannot hold a record header.
    #[test]
    fn malformed_too_short_record_header() {
        let bytes = vec![22, 0x03];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::TooShort));
    }

    /// Content type 23 (application data) is rejected on the first byte; the
    /// remaining bytes are never inspected.
    #[test]
    fn non_handshake_record_type() {
        let bytes = vec![23, 0x03, 0x03, 0x00, 0x10, 0xff, 0xff];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::NotHandshake));
    }

    #[test]
    fn oversized_record_length_rejected() {
        // length field 0xFFFF > 16384.
        let bytes = vec![22, 0x03, 0x03, 0xff, 0xff, 0x01];
        assert_eq!(extract_sni(&bytes), Err(SniParseError::RecordTooLarge));
    }

    #[test]
    fn truncated_random_is_truncation_error() {
        // Build a valid header but truncate inside random.
        let mut bytes = build_client_hello(&["api.example.com"]);
        // Truncate to 5 (record header) + 4 (handshake header) + 2 (legacy_version) + 10 = 21
        // NOTE: the record header still declares the full record length, so
        // the record-completeness check reports `Truncated` before `random` is
        // ever reached. The assertion accepts either variant.
        bytes.truncate(21);
        assert!(matches!(
            extract_sni(&bytes),
            Err(SniParseError::Truncated) | Err(SniParseError::InvalidLength)
        ));
    }

    /// Only the first host_name entry in the list is returned.
    #[test]
    fn multiple_sni_returns_first() {
        let bytes = build_client_hello(&["first.example.com", "second.example.com"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("first.example.com"));
    }

    #[test]
    fn ipv4_literal_sni_parses_but_is_caller_concern() {
        // Some clients put IP literals in SNI. We parse it; the caller's
        // allowlist matcher will reject IPs because allowlist entries are
        // FQDNs/wildcards.
        let bytes = build_client_hello(&["192.0.2.1"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("192.0.2.1"));
    }

    /// A zero-length host_name is reported as "no SNI" rather than `Some("")`.
    #[test]
    fn empty_sni_string_returns_ok_none() {
        let bytes = build_client_hello(&[""]);
        assert_eq!(extract_sni(&bytes), Ok(None));
    }

    /// FQDN-form names lose exactly one trailing dot.
    #[test]
    fn trailing_dot_is_stripped() {
        let bytes = build_client_hello(&["api.example.com."]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    /// ASCII case is folded to lowercase.
    #[test]
    fn uppercase_sni_is_lowercased() {
        let bytes = build_client_hello(&["API.Example.COM"]);
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("api.example.com"));
    }

    #[test]
    fn tls13_record_version_accepted() {
        // build_client_hello uses TLS 1.0 record version per RFC 8446 §5.1
        // — here we explicitly construct one with TLS 1.3 inner version to
        // confirm we don't gate on legacy_version.
        // Only the ClientHello body's legacy_version is patched; the record
        // header keeps 0x0301.
        let mut bytes = build_client_hello(&["modern.example.com"]);
        // Patch legacy_version inside ClientHello body to 0x0304 (TLS 1.3).
        // Layout: record(5) + hs_type(1) + hs_len(3) + body[0..2] = legacy_version
        bytes[9] = 0x03;
        bytes[10] = 0x04;
        let sni = extract_sni(&bytes).unwrap();
        assert_eq!(sni.as_deref(), Some("modern.example.com"));
    }

    /// Byte 5 is the handshake message type; 2 (ServerHello) must be refused.
    #[test]
    fn handshake_type_must_be_one() {
        let mut bytes = build_client_hello(&["api.example.com"]);
        bytes[5] = 2; // ServerHello
        assert_eq!(extract_sni(&bytes), Err(SniParseError::NotClientHello));
    }

    #[test]
    fn odd_cipher_suites_length_rejected() {
        // Patch cs_len to odd.
        let mut bytes = build_client_hello(&["api.example.com"]);
        // body offset = record(5) + hs(4) + legacy_version(2) + random(32) + sid_len(1) = 44
        // cs_len lives at byte 44..46.
        bytes[44] = 0x00;
        bytes[45] = 0x03;
        assert_eq!(extract_sni(&bytes), Err(SniParseError::InvalidLength));
    }
}