lynx_core/plugins/
http_request_plugin.rs

1use std::path::PathBuf;
2use std::str::FromStr;
3use std::sync::Arc;
4use std::{fs, io};
5
6use anyhow::{anyhow, Error, Result};
7use bytes::Bytes;
8use glob_match::glob_match;
9use http::header::{CONNECTION, CONTENT_LENGTH, HOST, PROXY_AUTHORIZATION};
10use http::uri::Scheme;
11use http::Uri;
12use http_body_util::combinators::BoxBody;
13use http_body_util::{BodyExt, StreamBody};
14use hyper::body::Incoming;
15use hyper::{Request, Response};
16use hyper_rustls::HttpsConnectorBuilder;
17use hyper_util::client::legacy::connect::HttpConnector;
18use hyper_util::client::legacy::Client;
19use hyper_util::rt::TokioExecutor;
20use sea_orm::EntityTrait;
21use tokio::io::AsyncWriteExt;
22use tokio::sync::mpsc;
23use tokio_stream::wrappers::ReceiverStream;
24use tracing::{error, trace, warn};
25
26use crate::entities::rule_content::{self, parse_rule_content};
27use crate::proxy_log::body_write_to_file::{req_body_file, res_body_file};
28use crate::schedular::get_req_trace_id;
29use crate::server_context::DB;
30
31pub async fn build_proxy_request(
32    req: Request<Incoming>,
33) -> Result<Request<BoxBody<bytes::Bytes, anyhow::Error>>> {
34    let trace_id = get_req_trace_id(&req);
35
36    let (parts, body) = req.into_parts();
37    let mut body = body
38        .map_err(|e| anyhow!(e).context("http request body box error"))
39        .boxed();
40    let (tx, rx) = mpsc::channel(1024);
41
42    let rec_stream = ReceiverStream::new(rx);
43    // let rs = rec_stream.;
44    let stream: BoxBody<Bytes, Error> = BodyExt::boxed(StreamBody::new(rec_stream));
45
46    tokio::spawn(async move {
47        let mut req_body_file = req_body_file(&trace_id).await;
48
49        while let Some(frame) = body.frame().await {
50            if let Ok(file) = &mut req_body_file {
51                if let Ok(frame) = &frame {
52                    if let Some(data) = frame.data_ref() {
53                        let res = file.write_all(data).await;
54                        if let Err(e) = res {
55                            error!("write file res: {:?}", e);
56                        }
57                    }
58                }
59            }
60            let _ = tx.send(frame).await;
61        }
62    });
63
64    let req_url = url::Url::parse(parts.uri.to_string().as_str()).unwrap();
65    let mut builder = hyper::Request::builder().method(parts.method);
66
67    let db = DB.get().unwrap();
68
69    let rules = rule_content::Entity::find().all(db).await?;
70
71    let mut match_handled = false;
72
73    for rule in rules {
74        trace!("current rule: {:?}", rule);
75        match parse_rule_content(rule.content) {
76            Ok(content) => {
77                let capture_glob_pattern_str = content.capture.uri;
78                let is_match = glob_match(&capture_glob_pattern_str, req_url.as_str());
79                trace!("is match: {}", is_match);
80                trace!("capture_glob_pattern_str: {}", capture_glob_pattern_str);
81                trace!("req_url: {}", req_url.as_str());
82                if is_match {
83                    match_handled = true;
84                    let pass_proxy_uri = url::Url::parse(&content.handler.proxy_pass);
85
86                    match pass_proxy_uri {
87                        Ok(pass_proxy_uri) => {
88                            let host = pass_proxy_uri.host_str();
89                            let port = pass_proxy_uri.port();
90
91                            let mut new_uri = req_url.clone();
92                            let _ = new_uri.set_scheme(pass_proxy_uri.scheme());
93                            let _ = new_uri.set_host(host);
94                            let _ = new_uri.set_port(port);
95
96                            trace!("new url: {:?}", new_uri);
97
98                            if let Ok(new_uri) = Uri::from_str(new_uri.as_str()) {
99                                builder = builder.uri(new_uri);
100                            } else {
101                                warn!("parse pass proxy uri error: {}", new_uri.as_str());
102                            }
103                        }
104                        Err(e) => {
105                            warn!("parse pass proxy uri error: {}", e);
106                        }
107                    }
108                }
109            }
110            Err(e) => {
111                warn!("parse rule content error: {}", e);
112            }
113        }
114    }
115
116    if !match_handled {
117        builder = builder.uri(parts.uri);
118    }
119
120    for (key, value) in parts.headers.into_iter() {
121        if let Some(key) = key {
122            if matches!(
123                &key,
124                &HOST | &CONNECTION | &PROXY_AUTHORIZATION | &CONTENT_LENGTH
125            ) {
126                continue;
127            }
128            builder = builder.header(key, value);
129        }
130    }
131
132    builder.body(stream).map_err(|e| anyhow!(e))
133}
134
135pub async fn build_proxy_response(
136    trace_id: Arc<String>,
137    res: Response<Incoming>,
138) -> Result<Response<BoxBody<bytes::Bytes, anyhow::Error>>> {
139    let (parts, body) = res.into_parts();
140
141    let mut body = body
142        .map_err(|e| anyhow!(e).context("http proxy body box error"))
143        .boxed();
144
145    let (tx, rx) = mpsc::channel(1024);
146
147    let rec_stream = ReceiverStream::new(rx);
148    // let rs = rec_stream.;
149    let stream: BoxBody<Bytes, Error> = BodyExt::boxed(StreamBody::new(rec_stream));
150
151    tokio::spawn(async move {
152        let mut res_body_file = res_body_file(&trace_id).await;
153
154        while let Some(frame) = body.frame().await {
155            if let Ok(file) = &mut res_body_file {
156                if let Ok(frame) = &frame {
157                    if let Some(data) = frame.data_ref() {
158                        let res = file.write_all(data).await;
159                        if let Err(e) = res {
160                            error!("write file res: {:?}", e);
161                        }
162                    }
163                }
164            }
165            let _ = tx.send(frame).await;
166        }
167    });
168    Ok(Response::from_parts(parts, stream))
169}
170
171pub async fn request(req: Request<Incoming>) -> Result<Response<Incoming>> {
172    let client_builder = Client::builder(TokioExecutor::new());
173    trace!("request: {:?}", req);
174    let proxy_req = build_proxy_request(req).await?;
175    trace!("proxy request: {:?}", proxy_req);
176    let proxy_res = if proxy_req.uri().scheme() == Some(&Scheme::HTTPS) {
177        trace!("fetch https request {}", proxy_req.uri());
178        #[cfg(feature = "test")]
179        let connect = get_test_root_ca(proxy_req.uri().host());
180
181        #[cfg(not(feature = "test"))]
182        let connect = HttpsConnectorBuilder::new()
183            .with_webpki_roots()
184            .https_only()
185            .enable_all_versions()
186            .build();
187
188        client_builder.build(connect).request(proxy_req).await
189    } else {
190        trace!("http request");
191        client_builder
192            .build(HttpConnector::new())
193            .request(proxy_req)
194            .await
195    };
196
197    proxy_res.map_err(|e| anyhow!(e))
198}
199
/// Builds the HTTPS connector used when the `test` feature is enabled.
///
/// For `localhost`, `127.0.0.1` and `::1` the connector trusts only the
/// certificates found in `tests/fixtures/RootCA.crt` under the crate's
/// manifest directory. Every other host gets the regular webpki-roots
/// connector.
///
/// # Panics
///
/// Panics when the fixture certificate cannot be opened or parsed.
#[cfg(feature = "test")]
fn get_test_root_ca(host: Option<&str>) -> hyper_rustls::HttpsConnector<HttpConnector> {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use tokio_rustls::rustls::{ClientConfig, RootCertStore};

    /// True for the literal name `localhost` and the exact loopback
    /// addresses `127.0.0.1` / `::1`; false for anything else or no host.
    fn is_localhost(host: Option<&str>) -> bool {
        host.map_or(false, |name| {
            name == "localhost"
                || match name.parse::<IpAddr>() {
                    Ok(IpAddr::V4(v4)) => v4 == Ipv4Addr::LOCALHOST,
                    Ok(IpAddr::V6(v6)) => v6 == Ipv6Addr::LOCALHOST,
                    Err(_) => false,
                }
        })
    }

    if is_localhost(host) {
        let fixture_path =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/RootCA.crt");
        let mut pem_reader = io::BufReader::new(fs::File::open(fixture_path).unwrap());

        // Read trust roots
        let fixture_certs = rustls_pemfile::certs(&mut pem_reader)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let mut trust_store = RootCertStore::empty();
        trust_store.add_parsable_certificates(fixture_certs);

        // TLS client config using the custom CA store for lookups
        let tls_config = ClientConfig::builder()
            .with_root_certificates(trust_store)
            .with_no_client_auth();

        return HttpsConnectorBuilder::new()
            .with_tls_config(tls_config)
            .https_only()
            .enable_all_versions()
            .build();
    }

    // Non-local hosts are verified against the public webpki roots.
    HttpsConnectorBuilder::new()
        .with_webpki_roots()
        .https_only()
        .enable_all_versions()
        .build()
}