lynx_core/plugins/
http_request_plugin.rs1use std::path::PathBuf;
2use std::str::FromStr;
3use std::sync::Arc;
4use std::{fs, io};
5
6use anyhow::{anyhow, Error, Result};
7use bytes::Bytes;
8use glob_match::glob_match;
9use http::header::{CONNECTION, CONTENT_LENGTH, HOST, PROXY_AUTHORIZATION};
10use http::uri::Scheme;
11use http::Uri;
12use http_body_util::combinators::BoxBody;
13use http_body_util::{BodyExt, StreamBody};
14use hyper::body::Incoming;
15use hyper::{Request, Response};
16use hyper_rustls::HttpsConnectorBuilder;
17use hyper_util::client::legacy::connect::HttpConnector;
18use hyper_util::client::legacy::Client;
19use hyper_util::rt::TokioExecutor;
20use sea_orm::EntityTrait;
21use tokio::io::AsyncWriteExt;
22use tokio::sync::mpsc;
23use tokio_stream::wrappers::ReceiverStream;
24use tracing::{error, trace, warn};
25
26use crate::entities::rule_content::{self, parse_rule_content};
27use crate::proxy_log::body_write_to_file::{req_body_file, res_body_file};
28use crate::schedular::get_req_trace_id;
29use crate::server_context::DB;
30
31pub async fn build_proxy_request(
32 req: Request<Incoming>,
33) -> Result<Request<BoxBody<bytes::Bytes, anyhow::Error>>> {
34 let trace_id = get_req_trace_id(&req);
35
36 let (parts, body) = req.into_parts();
37 let mut body = body
38 .map_err(|e| anyhow!(e).context("http request body box error"))
39 .boxed();
40 let (tx, rx) = mpsc::channel(1024);
41
42 let rec_stream = ReceiverStream::new(rx);
43 let stream: BoxBody<Bytes, Error> = BodyExt::boxed(StreamBody::new(rec_stream));
45
46 tokio::spawn(async move {
47 let mut req_body_file = req_body_file(&trace_id).await;
48
49 while let Some(frame) = body.frame().await {
50 if let Ok(file) = &mut req_body_file {
51 if let Ok(frame) = &frame {
52 if let Some(data) = frame.data_ref() {
53 let res = file.write_all(data).await;
54 if let Err(e) = res {
55 error!("write file res: {:?}", e);
56 }
57 }
58 }
59 }
60 let _ = tx.send(frame).await;
61 }
62 });
63
64 let req_url = url::Url::parse(parts.uri.to_string().as_str()).unwrap();
65 let mut builder = hyper::Request::builder().method(parts.method);
66
67 let db = DB.get().unwrap();
68
69 let rules = rule_content::Entity::find().all(db).await?;
70
71 let mut match_handled = false;
72
73 for rule in rules {
74 trace!("current rule: {:?}", rule);
75 match parse_rule_content(rule.content) {
76 Ok(content) => {
77 let capture_glob_pattern_str = content.capture.uri;
78 let is_match = glob_match(&capture_glob_pattern_str, req_url.as_str());
79 trace!("is match: {}", is_match);
80 trace!("capture_glob_pattern_str: {}", capture_glob_pattern_str);
81 trace!("req_url: {}", req_url.as_str());
82 if is_match {
83 match_handled = true;
84 let pass_proxy_uri = url::Url::parse(&content.handler.proxy_pass);
85
86 match pass_proxy_uri {
87 Ok(pass_proxy_uri) => {
88 let host = pass_proxy_uri.host_str();
89 let port = pass_proxy_uri.port();
90
91 let mut new_uri = req_url.clone();
92 let _ = new_uri.set_scheme(pass_proxy_uri.scheme());
93 let _ = new_uri.set_host(host);
94 let _ = new_uri.set_port(port);
95
96 trace!("new url: {:?}", new_uri);
97
98 if let Ok(new_uri) = Uri::from_str(new_uri.as_str()) {
99 builder = builder.uri(new_uri);
100 } else {
101 warn!("parse pass proxy uri error: {}", new_uri.as_str());
102 }
103 }
104 Err(e) => {
105 warn!("parse pass proxy uri error: {}", e);
106 }
107 }
108 }
109 }
110 Err(e) => {
111 warn!("parse rule content error: {}", e);
112 }
113 }
114 }
115
116 if !match_handled {
117 builder = builder.uri(parts.uri);
118 }
119
120 for (key, value) in parts.headers.into_iter() {
121 if let Some(key) = key {
122 if matches!(
123 &key,
124 &HOST | &CONNECTION | &PROXY_AUTHORIZATION | &CONTENT_LENGTH
125 ) {
126 continue;
127 }
128 builder = builder.header(key, value);
129 }
130 }
131
132 builder.body(stream).map_err(|e| anyhow!(e))
133}
134
135pub async fn build_proxy_response(
136 trace_id: Arc<String>,
137 res: Response<Incoming>,
138) -> Result<Response<BoxBody<bytes::Bytes, anyhow::Error>>> {
139 let (parts, body) = res.into_parts();
140
141 let mut body = body
142 .map_err(|e| anyhow!(e).context("http proxy body box error"))
143 .boxed();
144
145 let (tx, rx) = mpsc::channel(1024);
146
147 let rec_stream = ReceiverStream::new(rx);
148 let stream: BoxBody<Bytes, Error> = BodyExt::boxed(StreamBody::new(rec_stream));
150
151 tokio::spawn(async move {
152 let mut res_body_file = res_body_file(&trace_id).await;
153
154 while let Some(frame) = body.frame().await {
155 if let Ok(file) = &mut res_body_file {
156 if let Ok(frame) = &frame {
157 if let Some(data) = frame.data_ref() {
158 let res = file.write_all(data).await;
159 if let Err(e) = res {
160 error!("write file res: {:?}", e);
161 }
162 }
163 }
164 }
165 let _ = tx.send(frame).await;
166 }
167 });
168 Ok(Response::from_parts(parts, stream))
169}
170
171pub async fn request(req: Request<Incoming>) -> Result<Response<Incoming>> {
172 let client_builder = Client::builder(TokioExecutor::new());
173 trace!("request: {:?}", req);
174 let proxy_req = build_proxy_request(req).await?;
175 trace!("proxy request: {:?}", proxy_req);
176 let proxy_res = if proxy_req.uri().scheme() == Some(&Scheme::HTTPS) {
177 trace!("fetch https request {}", proxy_req.uri());
178 #[cfg(feature = "test")]
179 let connect = get_test_root_ca(proxy_req.uri().host());
180
181 #[cfg(not(feature = "test"))]
182 let connect = HttpsConnectorBuilder::new()
183 .with_webpki_roots()
184 .https_only()
185 .enable_all_versions()
186 .build();
187
188 client_builder.build(connect).request(proxy_req).await
189 } else {
190 trace!("http request");
191 client_builder
192 .build(HttpConnector::new())
193 .request(proxy_req)
194 .await
195 };
196
197 proxy_res.map_err(|e| anyhow!(e))
198}
199
/// Builds the HTTPS connector used when the crate is compiled with the `test`
/// feature.
///
/// Loopback targets are trusted via the fixture root CA at
/// `tests/fixtures/RootCA.crt` under `CARGO_MANIFEST_DIR`. Loopback means
/// `localhost` (case-insensitive) or any IPv4/IPv6 loopback literal. Every
/// other host gets the same webpki-roots connector as the non-test build.
///
/// # Panics
///
/// Panics if the fixture CA file cannot be opened or its PEM cannot be parsed.
/// This code only exists in `test`-feature builds.
#[cfg(feature = "test")]
fn get_test_root_ca(host: Option<&str>) -> hyper_rustls::HttpsConnector<HttpConnector> {
    use std::net::IpAddr;

    use tokio_rustls::rustls::{ClientConfig, RootCertStore};

    /// Returns true for `localhost` and loopback IP literals.
    ///
    /// `http::Uri::host()` keeps the brackets around IPv6 literals (`[::1]`),
    /// so they are stripped before parsing; otherwise IPv6 never matches.
    fn is_localhost(host: Option<&str>) -> bool {
        match host {
            None => false,
            Some(host) if host.eq_ignore_ascii_case("localhost") => true,
            Some(host) => {
                let literal = host
                    .strip_prefix('[')
                    .and_then(|h| h.strip_suffix(']'))
                    .unwrap_or(host);
                literal
                    .parse::<IpAddr>()
                    .map(|ip| ip.is_loopback())
                    .unwrap_or(false)
            }
        }
    }

    if !is_localhost(host) {
        return HttpsConnectorBuilder::new()
            .with_webpki_roots()
            .https_only()
            .enable_all_versions()
            .build();
    }

    let mut ca_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    ca_path.push("tests/fixtures/RootCA.crt");
    let ca_file = fs::File::open(&ca_path)
        .unwrap_or_else(|e| panic!("failed to open test root CA {}: {e}", ca_path.display()));
    let mut rd = io::BufReader::new(ca_file);

    let certs = rustls_pemfile::certs(&mut rd)
        .collect::<Result<Vec<_>, _>>()
        .expect("failed to parse test root CA PEM");
    let mut roots = RootCertStore::empty();
    roots.add_parsable_certificates(certs);
    let tls = ClientConfig::builder()
        .with_root_certificates(roots)
        .with_no_client_auth();

    HttpsConnectorBuilder::new()
        .with_tls_config(tls)
        .https_only()
        .enable_all_versions()
        .build()
}