1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
7use hyper::body::{Bytes, Frame};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::StreamExt;
10use tower::Service;
11use tower_http::services::{ServeDir, ServeFile};
12
13use crate::compression;
14use crate::logging::{log_request, log_response, LoggingBody};
15use crate::request::{resolve_trusted_ip, Request};
16use crate::response::{Response, ResponseBodyType, ResponseTransport};
17use crate::worker::spawn_eval_thread;
18
19type BoxError = Box<dyn std::error::Error + Send + Sync>;
20type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
21
22pub async fn handle<B>(
23 engine: Arc<ArcSwap<crate::Engine>>,
24 addr: Option<SocketAddr>,
25 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
26 req: hyper::Request<B>,
27) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
28where
29 B: hyper::body::Body + Unpin + Send + 'static,
30 B::Data: Into<Bytes> + Clone + Send,
31 B::Error: Into<BoxError> + Send,
32{
33 let engine = engine.load_full();
35 match handle_inner(engine, addr, trusted_proxies, req).await {
36 Ok(response) => Ok(response),
37 Err(err) => {
38 eprintln!("Error handling request: {err}");
39 let response = hyper::Response::builder().status(500).body(
40 Full::new(format!("Script error: {err}").into())
41 .map_err(|never| match never {})
42 .boxed(),
43 )?;
44 Ok(response)
45 }
46 }
47}
48
49async fn handle_inner<B>(
50 engine: Arc<crate::Engine>,
51 addr: Option<SocketAddr>,
52 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
53 req: hyper::Request<B>,
54) -> HTTPResult
55where
56 B: hyper::body::Body + Unpin + Send + 'static,
57 B::Data: Into<Bytes> + Clone + Send,
58 B::Error: Into<BoxError> + Send,
59{
60 let (parts, mut body) = req.into_parts();
61
62 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
64
65 tokio::task::spawn(async move {
67 while let Some(frame) = body.frame().await {
68 match frame {
69 Ok(frame) => {
70 if let Some(data) = frame.data_ref() {
71 let bytes: Bytes = (*data).clone().into();
72 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
73 break;
74 }
75 }
76 }
77 Err(err) => {
78 let _ = body_tx.send(Err(err.into())).await;
79 break;
80 }
81 }
82 }
83 });
84
85 let stream = nu_protocol::ByteStream::from_fn(
87 nu_protocol::Span::unknown(),
88 engine.state.signals().clone(),
89 nu_protocol::ByteStreamType::Unknown,
90 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
91 Some(Ok(bytes)) => {
92 buffer.extend_from_slice(&bytes);
93 Ok(true)
94 }
95 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
96 error: "Body read error".into(),
97 msg: err.to_string(),
98 span: None,
99 help: None,
100 inner: vec![],
101 }),
102 None => Ok(false),
103 },
104 );
105
106 let start_time = Instant::now();
108 let request_id = scru128::new();
109
110 let remote_ip = addr.as_ref().map(|a| a.ip());
111 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
112
113 let request = Request {
114 proto: format!("{:?}", parts.version),
115 method: parts.method.clone(),
116 authority: parts.uri.authority().map(|a| a.to_string()),
117 remote_ip,
118 remote_port: addr.as_ref().map(|a| a.port()),
119 trusted_ip,
120 headers: parts.headers.clone(),
121 uri: parts.uri.clone(),
122 path: parts.uri.path().to_string(),
123 query: parts
124 .uri
125 .query()
126 .map(|v| {
127 url::form_urlencoded::parse(v.as_bytes())
128 .into_owned()
129 .collect()
130 })
131 .unwrap_or_else(std::collections::HashMap::new),
132 };
133
134 log_request(request_id, &request);
136
137 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
138
139 let (meta, body_result): (
143 Response,
144 Result<(Option<String>, ResponseTransport), BoxError>,
145 ) = tokio::join!(
146 async {
147 meta_rx.await.unwrap_or(Response {
148 status: 200,
149 headers: std::collections::HashMap::new(),
150 body_type: ResponseBodyType::Normal,
151 })
152 },
153 async { bridged_body.await.map_err(|e| e.into()) }
154 );
155
156 let use_brotli = compression::accepts_brotli(&parts.headers);
157
158 match &meta.body_type {
159 ResponseBodyType::Normal => {
160 build_normal_response(&meta, Ok(body_result?), use_brotli, request_id, start_time).await
161 }
162 ResponseBodyType::Static {
163 root,
164 path,
165 fallback,
166 } => {
167 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
168 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
169 *static_req.method_mut() = parts.method.clone();
170 *static_req.headers_mut() = parts.headers.clone();
171
172 let res = if let Some(fallback) = fallback {
173 let fp = root.join(fallback);
174 ServeDir::new(root)
175 .fallback(ServeFile::new(fp))
176 .call(static_req)
177 .await?
178 } else {
179 ServeDir::new(root).call(static_req).await?
180 };
181 let (res_parts, body) = res.into_parts();
182 log_response(
183 request_id,
184 res_parts.status.as_u16(),
185 &res_parts.headers,
186 start_time,
187 );
188
189 let bytes = body.collect().await?.to_bytes();
190 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
191 let logging_body = LoggingBody::new(inner_body, request_id);
192 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
193 Ok(res)
194 }
195 ResponseBodyType::ReverseProxy {
196 target_url,
197 headers,
198 preserve_host,
199 strip_prefix,
200 request_body,
201 query,
202 } => {
203 let body = Full::new(Bytes::from(request_body.clone()));
204 let mut proxy_req = hyper::Request::new(body);
205
206 let path = if let Some(prefix) = strip_prefix {
208 parts
209 .uri
210 .path()
211 .strip_prefix(prefix)
212 .unwrap_or(parts.uri.path())
213 } else {
214 parts.uri.path()
215 };
216
217 let target_uri = {
219 let query_string = if let Some(custom_query) = query {
220 url::form_urlencoded::Serializer::new(String::new())
222 .extend_pairs(custom_query.iter())
223 .finish()
224 } else if let Some(orig_query) = parts.uri.query() {
225 orig_query.to_string()
227 } else {
228 String::new()
229 };
230
231 if query_string.is_empty() {
232 format!("{target_url}{path}")
233 } else {
234 format!("{target_url}{path}?{query_string}")
235 }
236 };
237
238 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
239 *proxy_req.method_mut() = parts.method.clone();
240
241 let mut header_map = parts.headers.clone();
243
244 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
246 header_map.insert(
247 hyper::header::CONTENT_LENGTH,
248 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
249 );
250 }
251
252 for (k, v) in headers {
254 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
255
256 match v {
257 crate::response::HeaderValue::Single(s) => {
258 let header_value = hyper::header::HeaderValue::from_str(s)?;
259 header_map.insert(header_name, header_value);
260 }
261 crate::response::HeaderValue::Multiple(values) => {
262 for value in values {
263 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
264 header_map.append(header_name.clone(), header_value);
265 }
266 }
267 }
268 }
269 }
270
271 if !preserve_host {
273 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
274 if let Some(authority) = target_uri.authority() {
275 header_map.insert(
276 hyper::header::HOST,
277 hyper::header::HeaderValue::from_str(authority.as_ref())?,
278 );
279 }
280 }
281 }
282
283 *proxy_req.headers_mut() = header_map;
284
285 let client =
287 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
288 .build_http();
289
290 match client.request(proxy_req).await {
291 Ok(response) => {
292 let (res_parts, body) = response.into_parts();
293 log_response(
294 request_id,
295 res_parts.status.as_u16(),
296 &res_parts.headers,
297 start_time,
298 );
299
300 let inner_body = body.map_err(|e| e.into()).boxed();
301 let logging_body = LoggingBody::new(inner_body, request_id);
302 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
303 Ok(res)
304 }
305 Err(_e) => {
306 let empty_headers = hyper::header::HeaderMap::new();
307 log_response(request_id, 502, &empty_headers, start_time);
308
309 let inner_body = Full::new("Bad Gateway".into())
310 .map_err(|never| match never {})
311 .boxed();
312 let logging_body = LoggingBody::new(inner_body, request_id);
313 let response = hyper::Response::builder()
314 .status(502)
315 .body(logging_body.boxed())?;
316 Ok(response)
317 }
318 }
319 }
320 }
321}
322
323async fn build_normal_response(
324 meta: &Response,
325 body_result: Result<(Option<String>, ResponseTransport), BoxError>,
326 use_brotli: bool,
327 request_id: scru128::Scru128Id,
328 start_time: Instant,
329) -> HTTPResult {
330 let (inferred_content_type, body) = body_result?;
331 let mut builder = hyper::Response::builder().status(meta.status);
332 let mut header_map = hyper::header::HeaderMap::new();
333
334 let content_type = meta
335 .headers
336 .get("content-type")
337 .or(meta.headers.get("Content-Type"))
338 .and_then(|hv| match hv {
339 crate::response::HeaderValue::Single(s) => Some(s.clone()),
340 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
341 })
342 .or(inferred_content_type)
343 .unwrap_or("text/html; charset=utf-8".to_string());
344
345 header_map.insert(
346 hyper::header::CONTENT_TYPE,
347 hyper::header::HeaderValue::from_str(&content_type)?,
348 );
349
350 if use_brotli {
352 header_map.insert(
353 hyper::header::CONTENT_ENCODING,
354 hyper::header::HeaderValue::from_static("br"),
355 );
356 header_map.insert(
357 hyper::header::VARY,
358 hyper::header::HeaderValue::from_static("accept-encoding"),
359 );
360 }
361
362 if content_type == "text/event-stream" {
364 header_map.insert(
365 hyper::header::CACHE_CONTROL,
366 hyper::header::HeaderValue::from_static("no-cache"),
367 );
368 header_map.insert(
369 hyper::header::CONNECTION,
370 hyper::header::HeaderValue::from_static("keep-alive"),
371 );
372 }
373
374 for (k, v) in &meta.headers {
375 if k.to_lowercase() != "content-type" {
376 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
377
378 match v {
379 crate::response::HeaderValue::Single(s) => {
380 let header_value = hyper::header::HeaderValue::from_str(s)?;
381 header_map.insert(header_name, header_value);
382 }
383 crate::response::HeaderValue::Multiple(values) => {
384 for value in values {
385 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
386 header_map.append(header_name.clone(), header_value);
387 }
388 }
389 }
390 }
391 }
392 }
393
394 log_response(request_id, meta.status, &header_map, start_time);
395 *builder.headers_mut().unwrap() = header_map;
396
397 let inner_body = match body {
398 ResponseTransport::Empty => Empty::<Bytes>::new()
399 .map_err(|never| match never {})
400 .boxed(),
401 ResponseTransport::Full(bytes) => {
402 if use_brotli {
403 let compressed = compression::compress_full(&bytes)?;
404 Full::new(Bytes::from(compressed))
405 .map_err(|never| match never {})
406 .boxed()
407 } else {
408 Full::new(bytes.into())
409 .map_err(|never| match never {})
410 .boxed()
411 }
412 }
413 ResponseTransport::Stream(rx) => {
414 if use_brotli {
415 compression::compress_stream(rx)
416 } else {
417 let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
418 StreamBody::new(stream).boxed()
419 }
420 }
421 };
422
423 let logging_body = LoggingBody::new(inner_body, request_id);
425 Ok(builder.body(logging_body.boxed())?)
426}