witmproxy 0.0.2-alpha

A WASM-in-the-middle proxy
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
use std::net::SocketAddr;
use std::sync::Arc;

use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, error, info, warn};

use crate::cert::CertificateAuthority;
use crate::config::TransparentProxyConfig;
use crate::events::Event;
use crate::events::connect::Connect;
use crate::plugins::registry::PluginRegistry;
use crate::proxy::tenant_resolver::TenantResolver;
use crate::proxy::{UpstreamClient, is_closed, parse_authority_host_port, run_tls_mitm};
use crate::tenant::TenantContext;

use super::netfilter::NetfilterManager;

/// Transparent proxy server that accepts raw TCP connections redirected by iptables.
///
/// Construct with [`TransparentProxy::new`], then call [`TransparentProxy::start`] to bind
/// the listener and spawn the accept loop.
pub struct TransparentProxy {
    /// Address the listener is actually bound to; `None` until `start` has bound it.
    listen_addr: Option<SocketAddr>,
    /// Certificate authority handed to `run_tls_mitm` for intercepted TLS connections.
    ca: Arc<CertificateAuthority>,
    /// Registry consulted (via a `Connect` event) to decide whether to intercept a TLS
    /// connection; `None` means connections are never intercepted.
    plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
    /// Resolves a `TenantContext` for each accepted peer address.
    tenant_resolver: Arc<dyn TenantResolver>,
    /// Upstream client handed to `run_tls_mitm` for intercepted TLS connections.
    upstream: UpstreamClient,
    /// Bind address, iptables automation flag and interface selection.
    config: TransparentProxyConfig,
    /// Notification that stops the accept loop spawned by `start`.
    shutdown_notify: Arc<Notify>,
    /// iptables rule manager, created by `start` when `config.auto_iptables` is enabled.
    /// NOTE(review): presumably kept here so the rules live (and are removed) with the
    /// proxy — confirm `NetfilterManager`'s drop behavior.
    netfilter: Option<NetfilterManager>,
}

impl TransparentProxy {
    /// Create a proxy that has not bound a socket yet.
    ///
    /// Call [`TransparentProxy::start`] to bind the listener and begin accepting connections.
    pub fn new(
        ca: Arc<CertificateAuthority>,
        plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
        tenant_resolver: Arc<dyn TenantResolver>,
        upstream: UpstreamClient,
        config: TransparentProxyConfig,
        shutdown_notify: Arc<Notify>,
    ) -> Self {
        Self {
            listen_addr: None,
            ca,
            plugin_registry,
            tenant_resolver,
            upstream,
            config,
            shutdown_notify,
            netfilter: None,
        }
    }

    /// Address the listener is bound to, or `None` before [`TransparentProxy::start`]
    /// has bound it.
    pub fn listen_addr(&self) -> Option<SocketAddr> {
        self.listen_addr
    }

    /// Bind the listener, optionally install iptables redirect rules, and spawn the accept loop.
    ///
    /// The bind address is `config.listen_addr` (default `0.0.0.0:8080`). When
    /// `config.auto_iptables` is set, a `NetfilterManager` is created for `config.interface`
    /// (default `tailscale0`) and the bound port. A rule-setup failure is only logged.
    ///
    /// Returns as soon as the accept loop is spawned. The loop exits once `shutdown_notify`
    /// fires. Each accepted connection is handled on its own task.
    ///
    /// # Errors
    /// Fails if the bind address does not parse, or if binding the listener or reading its
    /// local address fails.
    pub async fn start(&mut self) -> anyhow::Result<()> {
        let bind_addr: SocketAddr = self
            .config
            .listen_addr
            .as_deref()
            .unwrap_or("0.0.0.0:8080")
            .parse()
            .map_err(|e| anyhow::anyhow!("Invalid transparent proxy bind address: {}", e))?;

        let listener = TcpListener::bind(bind_addr).await?;
        // The actual address matters when the configured port is 0 (ephemeral).
        let local_addr = listener.local_addr()?;
        self.listen_addr = Some(local_addr);
        info!("Transparent proxy listening on {}", local_addr);

        // Set up iptables rules if configured
        if self.config.auto_iptables {
            let interface = self
                .config
                .interface
                .clone()
                .unwrap_or_else(|| "tailscale0".to_string());
            let mut nf = NetfilterManager::new(interface, local_addr.port());
            // Best effort: the listener still runs if the rules cannot be installed.
            if let Err(e) = nf.setup() {
                warn!("Failed to set up iptables rules: {}", e);
            }
            self.netfilter = Some(nf);
        }

        let shutdown = self.shutdown_notify.clone();
        let ca = self.ca.clone();
        let plugin_registry = self.plugin_registry.clone();
        let tenant_resolver = self.tenant_resolver.clone();
        let upstream = self.upstream.clone();

        tokio::spawn(async move {
            // Create the shutdown future once, outside the loop. A fresh `notified()` per
            // iteration leaves a gap between iterations. A `notify_waiters()` call landing
            // in that gap would go unobserved, and the loop would keep accepting after
            // shutdown was requested.
            let shutdown_signal = shutdown.notified();
            tokio::pin!(shutdown_signal);

            loop {
                tokio::select! {
                    _ = &mut shutdown_signal => break,
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((stream, peer)) => {
                                info!("Transparent: accepted connection from {}", peer);
                                let ca = ca.clone();
                                let plugin_registry = plugin_registry.clone();
                                let tenant_resolver = tenant_resolver.clone();
                                let upstream = upstream.clone();

                                tokio::spawn(async move {
                                    let tenant_ctx = tenant_resolver.resolve(&peer).await;
                                    if let Err(e) = handle_transparent_connection(
                                        stream,
                                        peer,
                                        ca,
                                        plugin_registry,
                                        upstream,
                                        tenant_ctx,
                                    )
                                    .await
                                        && !is_closed(&e)
                                    {
                                        debug!("Transparent connection error from {}: {}", peer, e);
                                    }
                                });
                            }
                            Err(e) => error!("Transparent accept error: {}", e),
                        }
                    }
                }
            }
        });

        Ok(())
    }
}

/// Extract the SNI (Server Name Indication) host name from a TLS ClientHello.
///
/// `buf` holds bytes peeked from the start of a client connection. Only the first TLS record
/// is parsed. Within it, only the bytes covered by the record length and the handshake length
/// are read; anything after them (for example a following record) is ignored. A buffer
/// shorter than those declared lengths (a partial read) is fine, as long as the SNI
/// extension itself is complete.
///
/// Returns `None` if:
/// - the data is not a handshake record carrying a ClientHello;
/// - the ClientHello ends or is malformed before the SNI extension;
/// - there is no `host_name` entry;
/// - the `host_name` entry is empty or not valid UTF-8.
pub fn extract_sni_from_client_hello(buf: &[u8]) -> Option<String> {
    const CONTENT_TYPE_HANDSHAKE: u8 = 22;
    const HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 1;
    const EXTENSION_SERVER_NAME: usize = 0;
    const NAME_TYPE_HOST_NAME: u8 = 0;

    /// Big-endian `u16` at offset `at`, widened to `usize`; `None` if it runs past `data`.
    fn read_u16(data: &[u8], at: usize) -> Option<usize> {
        let bytes = data.get(at..at.checked_add(2)?)?;
        Some(usize::from(u16::from_be_bytes([bytes[0], bytes[1]])))
    }

    /// First `host_name` entry of a `server_name` extension body, if non-empty UTF-8.
    fn host_name(ext_data: &[u8]) -> Option<String> {
        // ServerNameList: list length (2), then entries of type (1) + length (2) + name.
        let mut entries = ext_data.get(2..)?;
        while entries.len() >= 3 {
            let name_type = entries[0];
            let name_len = read_u16(entries, 1)?;
            let name = entries[3..].get(..name_len)?;
            if name_type == NAME_TYPE_HOST_NAME {
                // RFC 6066: HostName is opaque<1..2^16-1>, so an empty name is invalid.
                return std::str::from_utf8(name)
                    .ok()
                    .filter(|host| !host.is_empty())
                    .map(str::to_owned);
            }
            entries = &entries[3 + name_len..];
        }
        None
    }

    // TLS record header: content type (1) + legacy version (2) + length (2).
    if buf.len() < 5 || buf[0] != CONTENT_TYPE_HANDSHAKE {
        return None;
    }
    let record_len = read_u16(buf, 3)?;
    // Confine parsing to this record. A shorter (partially received) buffer is tolerated.
    let record = &buf[5..buf.len().min(5 + record_len)];

    // Handshake header: message type (1) + 24-bit length (3).
    if record.len() < 4 || record[0] != HANDSHAKE_TYPE_CLIENT_HELLO {
        return None;
    }
    let hs_len = (usize::from(record[1]) << 16)
        | (usize::from(record[2]) << 8)
        | usize::from(record[3]);
    let ch = &record[4..record.len().min(4 + hs_len)];

    // ClientHello: legacy version (2) + random (32), then length-prefixed session id (1),
    // cipher suites (2), compression methods (1) and extensions (2).
    let session_id_len = usize::from(*ch.get(34)?);
    let mut pos = 34 + 1 + session_id_len;
    let cipher_suites_len = read_u16(ch, pos)?;
    pos += 2 + cipher_suites_len;
    let compression_len = usize::from(*ch.get(pos)?);
    pos += 1 + compression_len;
    let extensions_len = read_u16(ch, pos)?;
    pos += 2;
    let mut extensions = &ch[pos..ch.len().min(pos + extensions_len)];

    // Each extension: type (2) + data length (2) + data.
    while extensions.len() >= 4 {
        let ext_type = read_u16(extensions, 0)?;
        let ext_data_len = read_u16(extensions, 2)?;
        let ext_data = extensions[4..].get(..ext_data_len)?;
        if ext_type == EXTENSION_SERVER_NAME {
            return host_name(ext_data);
        }
        extensions = &extensions[4 + ext_data_len..];
    }

    None
}

/// Ask the plugin registry whether any plugin would handle a connection to `hostname`.
///
/// `hostname` is parsed as an authority with default port 443. If parsing fails, the raw
/// string is used as the host with port 443. A `Connect` event for that host/port is then
/// checked against the registry. Always returns `false` when no registry is configured.
async fn should_intercept(
    plugin_registry: &Option<Arc<RwLock<PluginRegistry>>>,
    hostname: &str,
) -> bool {
    match plugin_registry {
        None => false,
        Some(registry) => {
            let (host, port) = parse_authority_host_port(hostname, 443)
                .unwrap_or_else(|_| (hostname.to_string(), 443));
            let event: Box<dyn Event> = Box::new(Connect::new(host, port));
            let guard = registry.read().await;
            guard.can_handle(&*event)
        }
    }
}

/// Handle a single transparent connection. Peeks to determine if it's TLS or plain HTTP.
///
/// * TLS (first byte is the handshake content type 22): the SNI host name is read from the
///   peeked ClientHello. If any plugin matches a Connect event for that host, the connection
///   goes through the shared [`run_tls_mitm`] pipeline. Otherwise it is forwarded as raw TCP
///   to `<sni>:443`. Connections without a usable SNI are closed.
/// * Anything else is treated as plain HTTP: the `Host` header is read from the peeked
///   request and raw TCP is forwarded to that host. The port is 80 unless the header names
///   one. Requests without a `Host` header are closed.
///
/// Peeking does not consume bytes, so the downstream handler sees the client's data from the
/// start.
///
/// # Errors
/// Returns an error if peeking fails or the upstream TCP connection cannot be established.
/// Errors during forwarding or MITM are logged at debug level and not returned.
async fn handle_transparent_connection(
    mut stream: TcpStream,
    peer: SocketAddr,
    ca: Arc<CertificateAuthority>,
    plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
    upstream: UpstreamClient,
    _tenant_ctx: TenantContext,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Record header (5 bytes) + largest TLS plaintext record payload (2^14 bytes). Any
    // ClientHello that fits in a single record fits in this buffer.
    const TLS_PEEK_LEN: usize = 5 + 16 * 1024;
    // How much of a plain-HTTP request is inspected when looking for the Host header.
    const HTTP_PEEK_LEN: usize = 8192;

    // Peek at the first bytes to determine protocol
    let mut peek_buf = [0u8; 5];
    let n = stream.peek(&mut peek_buf).await?;
    if n == 0 {
        return Ok(());
    }

    if peek_buf[0] == 22 {
        // NOTE(review): `peek` only sees bytes that have already arrived. A ClientHello split
        // across several TCP segments may be cut short here, in which case SNI extraction
        // can fail.
        let mut hello_buf = vec![0u8; TLS_PEEK_LEN];
        let n = stream.peek(&mut hello_buf).await?;

        let Some(hostname) = extract_sni_from_client_hello(&hello_buf[..n]) else {
            // Without SNI there is neither a host to match plugins against nor an upstream to
            // dial. Close the connection rather than acting on a placeholder name.
            warn!(
                "Could not extract SNI from ClientHello from {}; closing connection",
                peer
            );
            return Ok(());
        };

        info!("Transparent TLS: SNI={} from {}", hostname, peer);

        if should_intercept(&plugin_registry, &hostname).await {
            // Plugin(s) want this connection — run the full MITM pipeline
            info!("Transparent: intercepting {} (plugins matched)", hostname);
            let authority = format!("{}:443", hostname);
            if let Err(e) = run_tls_mitm(upstream, stream, authority, ca, plugin_registry).await
                && !is_closed(&e)
            {
                debug!("Transparent MITM error for {}: {}", hostname, e);
            }
        } else {
            // No plugins care — raw TCP forward to the real server
            info!(
                "Transparent: forwarding {} directly (no plugins matched)",
                hostname
            );
            let mut upstream_stream = TcpStream::connect(format!("{}:443", hostname)).await?;
            match tokio::io::copy_bidirectional(&mut stream, &mut upstream_stream).await {
                Ok(_) => {}
                Err(e) if is_closed(&e) => {}
                Err(e) => debug!("Transparent forward error for {}: {}", hostname, e),
            }
        }
    } else {
        // Plain HTTP — raw TCP forward (no MITM needed)
        info!("Transparent HTTP: forwarding from {}", peer);
        let mut buf = vec![0u8; HTTP_PEEK_LEN];
        let n = stream.peek(&mut buf).await?;

        let Some(host) = http_host_header(&buf[..n]) else {
            debug!("Transparent HTTP: no Host header found, dropping");
            return Ok(());
        };

        // A `Host` value includes the port when it is not the default (e.g.
        // `example.com:8080`). Appending `:80` to such a value would yield an
        // unconnectable `host:port:80`.
        let authority = if authority_has_port(&host) {
            host.clone()
        } else {
            format!("{}:80", host)
        };

        let mut upstream_stream = TcpStream::connect(authority.as_str()).await?;
        match tokio::io::copy_bidirectional(&mut stream, &mut upstream_stream).await {
            Ok(_) => {}
            Err(e) if is_closed(&e) => {}
            Err(e) => debug!("Transparent HTTP forward error for {}: {}", host, e),
        }
    }

    Ok(())
}

/// Trimmed value of the first `Host` header in a (possibly partial) HTTP/1.x request, or
/// `None` if there is no such header or its value is empty.
///
/// Only the header section is scanned: the request line is skipped, and the scan stops at
/// the first blank line. Bytes are decoded lossily, so a non-UTF-8 byte elsewhere in the
/// buffer cannot hide the header.
fn http_host_header(request: &[u8]) -> Option<String> {
    let text = String::from_utf8_lossy(request);
    text.lines()
        .skip(1)
        .take_while(|line| !line.is_empty())
        .find_map(|line| {
            let name = line.get(..5)?;
            name.eq_ignore_ascii_case("host:")
                .then(|| line[5..].trim().to_string())
        })
        .filter(|host| !host.is_empty())
}

/// Whether an HTTP `Host` value already names an explicit port (`example.com:8080`,
/// `[::1]:8080`) rather than being a bare host (`example.com`, `[::1]`).
fn authority_has_port(authority: &str) -> bool {
    authority.rsplit_once(':').is_some_and(|(host, port)| {
        !port.is_empty()
            && port.bytes().all(|b| b.is_ascii_digit())
            && (!host.contains(':') || host.ends_with(']'))
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_sni_from_real_client_hello() {
        // A minimal TLS 1.2 ClientHello carrying SNI "example.com".
        let hello = build_test_client_hello("example.com");
        assert_eq!(
            extract_sni_from_client_hello(&hello).as_deref(),
            Some("example.com")
        );
    }

    #[test]
    fn test_extract_sni_no_sni_extension() {
        // ClientHello that ends right after the compression methods (no extensions block).
        let hello = build_test_client_hello_no_sni();
        assert_eq!(extract_sni_from_client_hello(&hello), None);
    }

    #[test]
    fn test_extract_sni_not_tls() {
        let request: &[u8] = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        assert_eq!(extract_sni_from_client_hello(request), None);
    }

    #[test]
    fn test_extract_sni_empty() {
        assert_eq!(extract_sni_from_client_hello(&[]), None);
    }

    /// Fixed ClientHello prefix: TLS 1.2, all-zero random, empty session id,
    /// one cipher suite (0x00ff) and only the null compression method.
    fn client_hello_prefix() -> Vec<u8> {
        let mut body = vec![3u8, 3];
        body.extend_from_slice(&[0u8; 32]);
        // session_id len, cipher_suites len (2), suite 0x00ff, compression len, null method
        body.extend_from_slice(&[0, 0, 2, 0x00, 0xff, 1, 0]);
        body
    }

    /// Wrap a ClientHello body in a handshake header and a TLS 1.0 handshake record header.
    fn wrap_client_hello(body: &[u8]) -> Vec<u8> {
        // Handshake: type ClientHello (1) + 24-bit length whose high byte is always 0 here.
        let mut handshake = vec![1u8, 0];
        handshake.extend_from_slice(&(body.len() as u16).to_be_bytes());
        handshake.extend_from_slice(body);

        // Record: content type handshake (22), version 3.1, 16-bit length.
        let mut record = vec![22u8, 3, 1];
        record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
        record.extend_from_slice(&handshake);
        record
    }

    /// Build a minimal TLS ClientHello whose only extension is SNI for `hostname`.
    fn build_test_client_hello(hostname: &str) -> Vec<u8> {
        let name = hostname.as_bytes();

        // ServerNameList holding one host_name entry: type (0) + length + name.
        let mut entry = vec![0u8];
        entry.extend_from_slice(&(name.len() as u16).to_be_bytes());
        entry.extend_from_slice(name);
        let mut server_name_list = (entry.len() as u16).to_be_bytes().to_vec();
        server_name_list.extend_from_slice(&entry);

        // Extension: type server_name (0x0000) + data length + data.
        let mut extensions = vec![0u8, 0];
        extensions.extend_from_slice(&(server_name_list.len() as u16).to_be_bytes());
        extensions.extend_from_slice(&server_name_list);

        let mut body = client_hello_prefix();
        body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        body.extend_from_slice(&extensions);
        wrap_client_hello(&body)
    }

    /// Build a minimal TLS ClientHello with no extensions block at all.
    fn build_test_client_hello_no_sni() -> Vec<u8> {
        wrap_client_hello(&client_hello_prefix())
    }
}