Skip to main content

lingxia_proxy/
server.rs

1use crate::error::ProxyError;
2use crate::router::UpstreamConfig;
3use crate::router::{ProxyRouter, RouteDecision};
4use crate::upstream::connect_upstream;
5use http::Uri;
6use log::{debug, warn};
7use std::net::SocketAddr;
8use std::sync::{Arc, RwLock};
9#[cfg(feature = "capture")]
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream};
13
14#[cfg(feature = "capture")]
15use {
16    crate::capture::{CapturedSession, handle_http_sessions},
17    crate::mitm::CaConfig,
18    tokio::sync::broadcast,
19};
20
21/// A local HTTP CONNECT proxy.
22///
23/// ```text
24///  WebView  ──CONNECT──▶  LocalProxy  ──(router)──▶  Direct / SOCKS5 upstream
25/// ```
26///
27/// Clone is cheap (Arc-backed).  Pass one clone to `tokio::spawn(proxy.run())`,
28/// keep the other for `set_router()` / `set_capture_ca()` calls.
29#[derive(Clone)]
30pub struct LocalProxy {
31    inner: Arc<Inner>,
32}
33
34struct Inner {
35    router: RwLock<Arc<dyn ProxyRouter>>,
36    local_addr: SocketAddr,
37    listener: TcpListener,
38
39    #[cfg(feature = "capture")]
40    ca: RwLock<Option<Arc<CaConfig>>>,
41    #[cfg(feature = "capture")]
42    session_tx: broadcast::Sender<Arc<CapturedSession>>,
43}
44
45impl LocalProxy {
46    /// Bind on `addr` (use `"127.0.0.1:0"` to let the OS pick a free port).
47    pub async fn bind(
48        addr: &str,
49        initial_router: Arc<dyn ProxyRouter>,
50    ) -> Result<Self, ProxyError> {
51        let listener = TcpListener::bind(addr).await?;
52        let local_addr = listener.local_addr()?;
53
54        #[cfg(feature = "capture")]
55        let (session_tx, _) = broadcast::channel(512);
56
57        Ok(Self {
58            inner: Arc::new(Inner {
59                router: RwLock::new(initial_router),
60                local_addr,
61                listener,
62                #[cfg(feature = "capture")]
63                ca: RwLock::new(None),
64                #[cfg(feature = "capture")]
65                session_tx,
66            }),
67        })
68    }
69
70    pub fn local_addr(&self) -> SocketAddr {
71        self.inner.local_addr
72    }
73
74    /// Swap the routing policy.  Takes effect for all new connections.
75    pub fn set_router(&self, router: Arc<dyn ProxyRouter>) {
76        *self.inner.router.write().unwrap() = router;
77    }
78
79    /// Install the CA certificate and key used for HTTPS MITM.
80    ///
81    /// Call this before subscribing with `session_receiver()`.
82    /// The CA cert must already be installed in the system/browser trust store
83    /// by the caller — the proxy does not install it automatically.
84    ///
85    /// See README for how to generate and install the CA with mkcert or openssl.
86    #[cfg(feature = "capture")]
87    pub fn set_capture_ca(&self, ca: CaConfig) {
88        *self.inner.ca.write().unwrap() = Some(Arc::new(ca));
89    }
90
91    /// Subscribe to structured HTTP sessions from all captured tunnels.
92    ///
93    /// HTTPS tunnels are MITM-intercepted only when a CA has been set via
94    /// `set_capture_ca()`.  Without a CA, HTTPS traffic is forwarded opaquely.
95    ///
96    /// Filter in the consumer — it's one line, no proxy-side overhead:
97    /// ```ignore
98    /// while let Ok(s) = rx.recv().await {
99    ///     if !s.host.ends_with("openai.com") { continue; }
100    ///     agent.process(serde_json::to_value(&*s)?).await;
101    /// }
102    /// ```
103    #[cfg(feature = "capture")]
104    pub fn session_receiver(&self) -> broadcast::Receiver<Arc<CapturedSession>> {
105        self.inner.session_tx.subscribe()
106    }
107
108    /// Drive the accept loop.  Blocks until the listener errors.
109    pub async fn run(&self) {
110        loop {
111            match self.inner.listener.accept().await {
112                Ok((stream, peer)) => {
113                    debug!("proxy: accepted {peer}");
114                    let proxy = self.clone();
115                    tokio::spawn(async move {
116                        if let Err(e) = proxy.handle_connection(stream).await {
117                            debug!("proxy: {peer} — {e}");
118                        }
119                    });
120                }
121                Err(e) => {
122                    warn!("proxy: accept error: {e}");
123                    break;
124                }
125            }
126        }
127    }
128
129    // ── per-connection ────────────────────────────────────────────────────
130
131    async fn handle_connection(&self, mut stream: TcpStream) -> Result<(), ProxyError> {
132        match read_proxy_request(&mut stream).await? {
133            ProxyRequest::Connect { host, port } => {
134                let upstream_cfg = self.route_upstream(&host, port)?;
135
136                #[cfg(feature = "capture")]
137                {
138                    return self
139                        .handle_connect_with_capture(stream, host, port, upstream_cfg)
140                        .await;
141                }
142
143                #[cfg(not(feature = "capture"))]
144                {
145                    let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
146                    stream
147                        .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
148                        .await?;
149                    tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
150                    Ok(())
151                }
152            }
153            ProxyRequest::ForwardHttp {
154                host,
155                port,
156                initial_bytes,
157            } => {
158                let upstream_cfg = self.route_upstream(&host, port)?;
159
160                #[cfg(feature = "capture")]
161                {
162                    return self
163                        .handle_forward_http_with_capture(
164                            stream,
165                            initial_bytes,
166                            host,
167                            port,
168                            upstream_cfg,
169                        )
170                        .await;
171                }
172
173                #[cfg(not(feature = "capture"))]
174                {
175                    let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
176                    up.write_all(&initial_bytes).await?;
177                    tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
178                    Ok(())
179                }
180            }
181        }
182    }
183
184    fn route_upstream(&self, host: &str, port: u16) -> Result<UpstreamConfig, ProxyError> {
185        let router = self.inner.router.read().unwrap();
186        match router.route(host, port)? {
187            RouteDecision::Upstream(cfg) => Ok(cfg),
188            RouteDecision::Block => Err(ProxyError::UpstreamConnect(format!(
189                "{host}:{port} blocked by policy"
190            ))),
191        }
192    }
193
194    #[cfg(feature = "capture")]
195    async fn handle_connect_with_capture(
196        &self,
197        mut stream: TcpStream,
198        host: String,
199        port: u16,
200        upstream_cfg: UpstreamConfig,
201    ) -> Result<(), ProxyError> {
202        let tx = self.inner.session_tx.clone();
203        let has_consumer = tx.receiver_count() > 0;
204        let ca = self.inner.ca.read().unwrap().clone(); // Option<Arc<CaConfig>>
205
206        // HTTPS with active subscribers AND a CA loaded → MITM.
207        if port == 443 && has_consumer && ca.is_some() {
208            let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
209            stream
210                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
211                .await?;
212
213            let (client_tls, server_tls) =
214                crate::mitm::intercept(stream, &host, up_stream, ca.as_deref().unwrap()).await?;
215
216            handle_http_sessions(host, port, true, client_tls, server_tls, tx)
217                .await
218                .map_err(ProxyError::Io)?;
219
220            return Ok(());
221        }
222
223        // Plain HTTP with active subscribers → parse directly (no TLS needed).
224        if port != 443 && has_consumer {
225            let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
226            stream
227                .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
228                .await?;
229
230            handle_http_sessions(host, port, false, stream, up_stream, tx)
231                .await
232                .map_err(ProxyError::Io)?;
233
234            return Ok(());
235        }
236
237        // No subscribers, or HTTPS without CA → plain tunnel, zero overhead.
238        let mut up = connect_upstream(&upstream_cfg, &host, port).await?;
239        stream
240            .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
241            .await?;
242        tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
243        Ok(())
244    }
245
246    #[cfg(feature = "capture")]
247    async fn handle_forward_http_with_capture(
248        &self,
249        stream: TcpStream,
250        initial_bytes: Vec<u8>,
251        host: String,
252        port: u16,
253        upstream_cfg: UpstreamConfig,
254    ) -> Result<(), ProxyError> {
255        let tx = self.inner.session_tx.clone();
256        let has_consumer = tx.receiver_count() > 0;
257
258        let up_stream = connect_upstream(&upstream_cfg, &host, port).await?;
259        if has_consumer {
260            let client = PrefixedIo::new(stream, initial_bytes);
261            handle_http_sessions(host, port, false, client, up_stream, tx)
262                .await
263                .map_err(ProxyError::Io)?;
264            return Ok(());
265        }
266
267        let mut stream = stream;
268        let mut up = up_stream;
269        up.write_all(&initial_bytes).await?;
270        tokio::io::copy_bidirectional(&mut stream, &mut up).await?;
271        Ok(())
272    }
273}
274
/// The two kinds of request a client can open a proxy connection with.
enum ProxyRequest {
    /// `CONNECT host:port` — the client asks for a tunnel.
    Connect {
        /// Destination host taken from the request target.
        host: String,
        /// Destination port taken from the request target.
        port: u16,
    },
    /// Any other method — a plain HTTP request to be forwarded.
    ForwardHttp {
        /// Destination host, from the absolute-form target or the `Host` header.
        host: String,
        /// Destination port, defaulting to 80 when none is given.
        port: u16,
        /// Request head to send upstream: the rewritten request line followed
        /// by the header bytes exactly as received.
        initial_bytes: Vec<u8>,
    },
}
286
287// ── HTTP proxy request parser ──────────────────────────────────────────────
288
289async fn read_proxy_request(stream: &mut TcpStream) -> Result<ProxyRequest, ProxyError> {
290    let mut buf = Vec::with_capacity(512);
291    let mut tmp = [0u8; 1];
292
293    loop {
294        stream.read_exact(&mut tmp).await?;
295        buf.push(tmp[0]);
296        if buf.ends_with(b"\r\n\r\n") {
297            break;
298        }
299        if buf.len() > 8192 {
300            return Err(ProxyError::BadRequest("CONNECT headers too large".into()));
301        }
302    }
303
304    let text =
305        std::str::from_utf8(&buf).map_err(|_| ProxyError::BadRequest("Non-UTF8 CONNECT".into()))?;
306
307    let first_line = text
308        .lines()
309        .next()
310        .ok_or_else(|| ProxyError::BadRequest("Empty request".into()))?;
311
312    let mut parts = first_line.split_whitespace();
313    let method = parts
314        .next()
315        .ok_or_else(|| ProxyError::BadRequest("No method".into()))?;
316    let target = parts
317        .next()
318        .ok_or_else(|| ProxyError::BadRequest("No request target".into()))?;
319    let version = parts
320        .next()
321        .ok_or_else(|| ProxyError::BadRequest("No HTTP version".into()))?;
322
323    if method.eq_ignore_ascii_case("CONNECT") {
324        let (host, port) = parse_authority_host_port(target, None)?;
325        return Ok(ProxyRequest::Connect { host, port });
326    }
327
328    let (host, port, upstream_target) = parse_forward_target(target, text)?;
329    let first_line_end = buf
330        .windows(2)
331        .position(|window| window == b"\r\n")
332        .ok_or_else(|| ProxyError::BadRequest("Malformed request line".into()))?;
333    let mut initial_bytes = format!("{method} {upstream_target} {version}\r\n").into_bytes();
334    initial_bytes.extend_from_slice(&buf[first_line_end + 2..]);
335    Ok(ProxyRequest::ForwardHttp {
336        host,
337        port,
338        initial_bytes,
339    })
340}
341
342fn parse_forward_target(
343    target: &str,
344    request_text: &str,
345) -> Result<(String, u16, String), ProxyError> {
346    if target.starts_with("http://") || target.starts_with("https://") {
347        let uri: Uri = target
348            .parse()
349            .map_err(|e| ProxyError::BadRequest(format!("Bad absolute-form URI: {e}")))?;
350        let scheme = uri
351            .scheme_str()
352            .ok_or_else(|| ProxyError::BadRequest("Absolute-form URI missing scheme".into()))?;
353        if scheme.eq_ignore_ascii_case("https") {
354            return Err(ProxyError::BadRequest(
355                "HTTPS absolute-form requests must use CONNECT".into(),
356            ));
357        }
358        let host = uri
359            .host()
360            .ok_or_else(|| ProxyError::BadRequest("Absolute-form URI missing host".into()))?
361            .to_string();
362        let port = uri.port_u16().unwrap_or(80);
363        let path = uri
364            .path_and_query()
365            .map(|pq| pq.as_str().to_string())
366            .unwrap_or_else(|| "/".to_string());
367        return Ok((host, port, path));
368    }
369
370    let authority = request_text
371        .lines()
372        .skip(1)
373        .find_map(|line| {
374            let (name, value) = line.split_once(':')?;
375            name.trim()
376                .eq_ignore_ascii_case("host")
377                .then(|| value.trim().to_string())
378        })
379        .ok_or_else(|| ProxyError::BadRequest("HTTP proxy request missing Host header".into()))?;
380    let (host, port) = parse_authority_host_port(&authority, Some(80))?;
381    Ok((host, port, target.to_string()))
382}
383
384fn parse_authority_host_port(
385    authority: &str,
386    default_port: Option<u16>,
387) -> Result<(String, u16), ProxyError> {
388    let authority = authority.trim();
389    if authority.is_empty() {
390        return Err(ProxyError::BadRequest("Empty authority".into()));
391    }
392    if let Some(host) = authority.strip_prefix('[')
393        && let Some((host, rest)) = host.split_once(']')
394    {
395        let port = if let Some(rest) = rest.strip_prefix(':') {
396            rest.parse()
397                .map_err(|_| ProxyError::BadRequest(format!("Bad port in '{authority}'")))?
398        } else {
399            default_port
400                .ok_or_else(|| ProxyError::BadRequest(format!("No port in '{authority}'")))?
401        };
402        return Ok((host.to_string(), port));
403    }
404
405    if let Some((host, port)) = authority.rsplit_once(':')
406        && !host.is_empty()
407        && let Ok(port) = port.parse()
408    {
409        return Ok((host.to_string(), port));
410    }
411
412    Ok((
413        authority.to_string(),
414        default_port.ok_or_else(|| ProxyError::BadRequest(format!("No port in '{authority}'")))?,
415    ))
416}
417
/// An I/O wrapper that yields a fixed byte prefix before the reads of the
/// wrapped value.
#[cfg(feature = "capture")]
struct PrefixedIo<T> {
    /// Bytes handed out ahead of anything read from `inner`.
    prefix: Vec<u8>,
    /// How many bytes of `prefix` have already been handed out.
    prefix_pos: usize,
    /// The wrapped reader/writer.
    inner: T,
}

#[cfg(feature = "capture")]
impl<T> PrefixedIo<T> {
    /// Wrap `inner` so that reads return `prefix` first.
    fn new(inner: T, prefix: Vec<u8>) -> Self {
        let prefix_pos = 0;
        Self {
            prefix,
            prefix_pos,
            inner,
        }
    }
}
435
#[cfg(feature = "capture")]
impl<T> AsyncRead for PrefixedIo<T>
where
    T: AsyncRead + Unpin,
{
    /// Serve bytes from the stored prefix while any remain; after that,
    /// delegate to the wrapped reader.  A single call never mixes the two
    /// sources.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();

        let pending = &this.prefix[this.prefix_pos..];
        if pending.is_empty() {
            return std::pin::Pin::new(&mut this.inner).poll_read(cx, buf);
        }

        // Copy as much of the prefix as the caller's buffer can take.
        let count = std::cmp::min(pending.len(), buf.remaining());
        buf.put_slice(&pending[..count]);
        this.prefix_pos += count;
        std::task::Poll::Ready(Ok(()))
    }
}
456
/// Writes are not affected by the prefix: every operation is passed straight
/// to the wrapped value.
#[cfg(feature = "capture")]
impl<T> AsyncWrite for PrefixedIo<T>
where
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let inner = &mut self.get_mut().inner;
        std::pin::Pin::new(inner).poll_write(cx, buf)
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        std::pin::Pin::new(inner).poll_flush(cx)
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let inner = &mut self.get_mut().inner;
        std::pin::Pin::new(inner).poll_shutdown(cx)
    }
}