1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
7use hyper::body::{Bytes, Frame};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::StreamExt;
10use tower::Service;
11use tower_http::services::{ServeDir, ServeFile};
12
13use crate::compression;
14use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
15use crate::request::{resolve_trusted_ip, Request};
16use crate::response::{Response, ResponseBodyType, ResponseTransport};
17use crate::worker::spawn_eval_thread;
18
19type BoxError = Box<dyn std::error::Error + Send + Sync>;
20type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
21
22pub async fn handle<B>(
23 engine: Arc<ArcSwap<crate::Engine>>,
24 addr: Option<SocketAddr>,
25 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
26 req: hyper::Request<B>,
27) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
28where
29 B: hyper::body::Body + Unpin + Send + 'static,
30 B::Data: Into<Bytes> + Clone + Send,
31 B::Error: Into<BoxError> + Send,
32{
33 let engine = engine.load_full();
35 match handle_inner(engine, addr, trusted_proxies, req).await {
36 Ok(response) => Ok(response),
37 Err(err) => {
38 eprintln!("Error handling request: {err}");
39 let response = hyper::Response::builder().status(500).body(
40 Full::new(format!("Script error: {err}").into())
41 .map_err(|never| match never {})
42 .boxed(),
43 )?;
44 Ok(response)
45 }
46 }
47}
48
49async fn handle_inner<B>(
50 engine: Arc<crate::Engine>,
51 addr: Option<SocketAddr>,
52 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
53 req: hyper::Request<B>,
54) -> HTTPResult
55where
56 B: hyper::body::Body + Unpin + Send + 'static,
57 B::Data: Into<Bytes> + Clone + Send,
58 B::Error: Into<BoxError> + Send,
59{
60 let (parts, mut body) = req.into_parts();
61
62 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
64
65 tokio::task::spawn(async move {
67 while let Some(frame) = body.frame().await {
68 match frame {
69 Ok(frame) => {
70 if let Some(data) = frame.data_ref() {
71 let bytes: Bytes = (*data).clone().into();
72 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
73 break;
74 }
75 }
76 }
77 Err(err) => {
78 let _ = body_tx.send(Err(err.into())).await;
79 break;
80 }
81 }
82 }
83 });
84
85 let stream = nu_protocol::ByteStream::from_fn(
87 nu_protocol::Span::unknown(),
88 engine.state.signals().clone(),
89 nu_protocol::ByteStreamType::Unknown,
90 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
91 Some(Ok(bytes)) => {
92 buffer.extend_from_slice(&bytes);
93 Ok(true)
94 }
95 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
96 error: "Body read error".into(),
97 msg: err.to_string(),
98 span: None,
99 help: None,
100 inner: vec![],
101 }),
102 None => Ok(false),
103 },
104 );
105
106 let start_time = Instant::now();
108 let request_id = scru128::new();
109 let guard = RequestGuard::new(request_id);
110
111 let remote_ip = addr.as_ref().map(|a| a.ip());
112 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
113
114 let request = Request {
115 proto: format!("{:?}", parts.version),
116 method: parts.method.clone(),
117 authority: parts.uri.authority().map(|a| a.to_string()),
118 remote_ip,
119 remote_port: addr.as_ref().map(|a| a.port()),
120 trusted_ip,
121 headers: parts.headers.clone(),
122 uri: parts.uri.clone(),
123 path: parts.uri.path().to_string(),
124 query: parts
125 .uri
126 .query()
127 .map(|v| {
128 url::form_urlencoded::parse(v.as_bytes())
129 .into_owned()
130 .collect()
131 })
132 .unwrap_or_else(std::collections::HashMap::new),
133 };
134
135 log_request(request_id, &request);
137
138 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
139
140 let (meta, body_result): (
144 Response,
145 Result<(Option<String>, ResponseTransport), BoxError>,
146 ) = tokio::join!(
147 async {
148 meta_rx.await.unwrap_or(Response {
149 status: 200,
150 headers: std::collections::HashMap::new(),
151 body_type: ResponseBodyType::Normal,
152 })
153 },
154 async { bridged_body.await.map_err(|e| e.into()) }
155 );
156
157 let use_brotli = compression::accepts_brotli(&parts.headers);
158
159 match &meta.body_type {
160 ResponseBodyType::Normal => {
161 build_normal_response(&meta, Ok(body_result?), use_brotli, guard, start_time).await
162 }
163 ResponseBodyType::Static {
164 root,
165 path,
166 fallback,
167 } => {
168 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
169 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
170 *static_req.method_mut() = parts.method.clone();
171 *static_req.headers_mut() = parts.headers.clone();
172
173 let res = if let Some(fallback) = fallback {
174 let fp = root.join(fallback);
175 ServeDir::new(root)
176 .fallback(ServeFile::new(fp))
177 .call(static_req)
178 .await?
179 } else {
180 ServeDir::new(root).call(static_req).await?
181 };
182 let (res_parts, body) = res.into_parts();
183 log_response(
184 request_id,
185 res_parts.status.as_u16(),
186 &res_parts.headers,
187 start_time,
188 );
189
190 let bytes = body.collect().await?.to_bytes();
191 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
192 let logging_body = LoggingBody::new(inner_body, guard);
193 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
194 Ok(res)
195 }
196 ResponseBodyType::ReverseProxy {
197 target_url,
198 headers,
199 preserve_host,
200 strip_prefix,
201 request_body,
202 query,
203 } => {
204 let body = Full::new(Bytes::from(request_body.clone()));
205 let mut proxy_req = hyper::Request::new(body);
206
207 let path = if let Some(prefix) = strip_prefix {
209 parts
210 .uri
211 .path()
212 .strip_prefix(prefix)
213 .unwrap_or(parts.uri.path())
214 } else {
215 parts.uri.path()
216 };
217
218 let target_uri = {
220 let query_string = if let Some(custom_query) = query {
221 url::form_urlencoded::Serializer::new(String::new())
223 .extend_pairs(custom_query.iter())
224 .finish()
225 } else if let Some(orig_query) = parts.uri.query() {
226 orig_query.to_string()
228 } else {
229 String::new()
230 };
231
232 if query_string.is_empty() {
233 format!("{target_url}{path}")
234 } else {
235 format!("{target_url}{path}?{query_string}")
236 }
237 };
238
239 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
240 *proxy_req.method_mut() = parts.method.clone();
241
242 let mut header_map = parts.headers.clone();
244
245 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
247 header_map.insert(
248 hyper::header::CONTENT_LENGTH,
249 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
250 );
251 }
252
253 for (k, v) in headers {
255 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
256
257 match v {
258 crate::response::HeaderValue::Single(s) => {
259 let header_value = hyper::header::HeaderValue::from_str(s)?;
260 header_map.insert(header_name, header_value);
261 }
262 crate::response::HeaderValue::Multiple(values) => {
263 for value in values {
264 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
265 header_map.append(header_name.clone(), header_value);
266 }
267 }
268 }
269 }
270 }
271
272 if !preserve_host {
274 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
275 if let Some(authority) = target_uri.authority() {
276 header_map.insert(
277 hyper::header::HOST,
278 hyper::header::HeaderValue::from_str(authority.as_ref())?,
279 );
280 }
281 }
282 }
283
284 *proxy_req.headers_mut() = header_map;
285
286 let client =
288 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
289 .build_http();
290
291 match client.request(proxy_req).await {
292 Ok(response) => {
293 let (res_parts, body) = response.into_parts();
294 log_response(
295 request_id,
296 res_parts.status.as_u16(),
297 &res_parts.headers,
298 start_time,
299 );
300
301 let inner_body = body.map_err(|e| e.into()).boxed();
302 let logging_body = LoggingBody::new(inner_body, guard);
303 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
304 Ok(res)
305 }
306 Err(_e) => {
307 let empty_headers = hyper::header::HeaderMap::new();
308 log_response(request_id, 502, &empty_headers, start_time);
309
310 let inner_body = Full::new("Bad Gateway".into())
311 .map_err(|never| match never {})
312 .boxed();
313 let logging_body = LoggingBody::new(inner_body, guard);
314 let response = hyper::Response::builder()
315 .status(502)
316 .body(logging_body.boxed())?;
317 Ok(response)
318 }
319 }
320 }
321 }
322}
323
324async fn build_normal_response(
325 meta: &Response,
326 body_result: Result<(Option<String>, ResponseTransport), BoxError>,
327 use_brotli: bool,
328 guard: RequestGuard,
329 start_time: Instant,
330) -> HTTPResult {
331 let request_id = guard.request_id();
332 let (inferred_content_type, body) = body_result?;
333 let mut builder = hyper::Response::builder().status(meta.status);
334 let mut header_map = hyper::header::HeaderMap::new();
335
336 let content_type = meta
337 .headers
338 .get("content-type")
339 .or(meta.headers.get("Content-Type"))
340 .and_then(|hv| match hv {
341 crate::response::HeaderValue::Single(s) => Some(s.clone()),
342 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
343 })
344 .or(inferred_content_type)
345 .unwrap_or("text/html; charset=utf-8".to_string());
346
347 header_map.insert(
348 hyper::header::CONTENT_TYPE,
349 hyper::header::HeaderValue::from_str(&content_type)?,
350 );
351
352 if use_brotli {
354 header_map.insert(
355 hyper::header::CONTENT_ENCODING,
356 hyper::header::HeaderValue::from_static("br"),
357 );
358 header_map.insert(
359 hyper::header::VARY,
360 hyper::header::HeaderValue::from_static("accept-encoding"),
361 );
362 }
363
364 if content_type == "text/event-stream" {
366 header_map.insert(
367 hyper::header::CACHE_CONTROL,
368 hyper::header::HeaderValue::from_static("no-cache"),
369 );
370 header_map.insert(
371 hyper::header::CONNECTION,
372 hyper::header::HeaderValue::from_static("keep-alive"),
373 );
374 }
375
376 for (k, v) in &meta.headers {
377 if k.to_lowercase() != "content-type" {
378 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
379
380 match v {
381 crate::response::HeaderValue::Single(s) => {
382 let header_value = hyper::header::HeaderValue::from_str(s)?;
383 header_map.insert(header_name, header_value);
384 }
385 crate::response::HeaderValue::Multiple(values) => {
386 for value in values {
387 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
388 header_map.append(header_name.clone(), header_value);
389 }
390 }
391 }
392 }
393 }
394 }
395
396 log_response(request_id, meta.status, &header_map, start_time);
397 *builder.headers_mut().unwrap() = header_map;
398
399 let inner_body = match body {
400 ResponseTransport::Empty => Empty::<Bytes>::new()
401 .map_err(|never| match never {})
402 .boxed(),
403 ResponseTransport::Full(bytes) => {
404 if use_brotli {
405 let compressed = compression::compress_full(&bytes)?;
406 Full::new(Bytes::from(compressed))
407 .map_err(|never| match never {})
408 .boxed()
409 } else {
410 Full::new(bytes.into())
411 .map_err(|never| match never {})
412 .boxed()
413 }
414 }
415 ResponseTransport::Stream(rx) => {
416 if use_brotli {
417 compression::compress_stream(rx)
418 } else {
419 let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
420 StreamBody::new(stream).boxed()
421 }
422 }
423 };
424
425 let logging_body = LoggingBody::new(inner_body, guard);
427 Ok(builder.body(logging_body.boxed())?)
428}