1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use futures_util::StreamExt;
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
8use hyper::body::{Bytes, Frame};
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12use tower_http::services::{ServeDir, ServeFile};
13
14use crate::compression;
15use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
16use crate::request::{resolve_trusted_ip, Request};
17use crate::response::{Response, ResponseBodyType, ResponseTransport};
18use crate::worker::{spawn_eval_thread, PipelineResult};
19
/// Boxed, thread-safe error type used by every handler in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Handler result: a response with a boxed `Bytes` body, or a boxed error.
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
22
23pub async fn handle<B>(
24 engine: Arc<ArcSwap<crate::Engine>>,
25 addr: Option<SocketAddr>,
26 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
27 req: hyper::Request<B>,
28) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
29where
30 B: hyper::body::Body + Unpin + Send + 'static,
31 B::Data: Into<Bytes> + Clone + Send,
32 B::Error: Into<BoxError> + Send,
33{
34 let engine = engine.load_full();
36 match handle_inner(engine, addr, trusted_proxies, req).await {
37 Ok(response) => Ok(response),
38 Err(err) => {
39 eprintln!("Error handling request: {err}");
40 let response = hyper::Response::builder().status(500).body(
41 Full::new(format!("Script error: {err}").into())
42 .map_err(|never| match never {})
43 .boxed(),
44 )?;
45 Ok(response)
46 }
47 }
48}
49
50async fn handle_inner<B>(
51 engine: Arc<crate::Engine>,
52 addr: Option<SocketAddr>,
53 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
54 req: hyper::Request<B>,
55) -> HTTPResult
56where
57 B: hyper::body::Body + Unpin + Send + 'static,
58 B::Data: Into<Bytes> + Clone + Send,
59 B::Error: Into<BoxError> + Send,
60{
61 let (parts, mut body) = req.into_parts();
62
63 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
65
66 tokio::task::spawn(async move {
68 while let Some(frame) = body.frame().await {
69 match frame {
70 Ok(frame) => {
71 if let Some(data) = frame.data_ref() {
72 let bytes: Bytes = (*data).clone().into();
73 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
74 break;
75 }
76 }
77 }
78 Err(err) => {
79 let _ = body_tx.send(Err(err.into())).await;
80 break;
81 }
82 }
83 }
84 });
85
86 let stream = nu_protocol::ByteStream::from_fn(
88 nu_protocol::Span::unknown(),
89 engine.state.signals().clone(),
90 nu_protocol::ByteStreamType::Unknown,
91 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
92 Some(Ok(bytes)) => {
93 buffer.extend_from_slice(&bytes);
94 Ok(true)
95 }
96 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
97 error: "Body read error".into(),
98 msg: err.to_string(),
99 span: None,
100 help: None,
101 inner: vec![],
102 }),
103 None => Ok(false),
104 },
105 );
106
107 let start_time = Instant::now();
109 let request_id = scru128::new();
110 let guard = RequestGuard::new(request_id);
111
112 let remote_ip = addr.as_ref().map(|a| a.ip());
113 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
114
115 let request = Request {
116 proto: format!("{:?}", parts.version),
117 method: parts.method.clone(),
118 authority: parts.uri.authority().map(|a| a.to_string()),
119 remote_ip,
120 remote_port: addr.as_ref().map(|a| a.port()),
121 trusted_ip,
122 headers: parts.headers.clone(),
123 uri: parts.uri.clone(),
124 path: parts.uri.path().to_string(),
125 query: parts
126 .uri
127 .query()
128 .map(|v| {
129 url::form_urlencoded::parse(v.as_bytes())
130 .into_owned()
131 .collect()
132 })
133 .unwrap_or_else(std::collections::HashMap::new),
134 };
135
136 log_request(request_id, &request);
138
139 let reload_token = engine.reload_token.clone();
140 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
141
142 let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
146 tokio::join!(async { meta_rx.await.ok() }, async {
147 bridged_body.await.map_err(|e| e.into())
148 });
149
150 let use_brotli = compression::accepts_brotli(&parts.headers);
151
152 match special_response.as_ref().map(|r| &r.body_type) {
154 Some(ResponseBodyType::Normal) | None => {
155 build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
157 }
158 Some(ResponseBodyType::Static {
159 root,
160 path,
161 fallback,
162 }) => {
163 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
164 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
165 *static_req.method_mut() = parts.method.clone();
166 *static_req.headers_mut() = parts.headers.clone();
167
168 let res = if let Some(fallback) = fallback {
169 let fp = root.join(fallback);
170 ServeDir::new(root)
171 .fallback(ServeFile::new(fp))
172 .call(static_req)
173 .await?
174 } else {
175 ServeDir::new(root).call(static_req).await?
176 };
177 let (res_parts, body) = res.into_parts();
178 log_response(
179 request_id,
180 res_parts.status.as_u16(),
181 &res_parts.headers,
182 start_time,
183 );
184
185 let bytes = body.collect().await?.to_bytes();
186 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
187 let logging_body = LoggingBody::new(inner_body, guard);
188 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
189 Ok(res)
190 }
191 Some(ResponseBodyType::ReverseProxy {
192 target_url,
193 headers,
194 preserve_host,
195 strip_prefix,
196 request_body,
197 query,
198 }) => {
199 let body = Full::new(Bytes::from(request_body.clone()));
200 let mut proxy_req = hyper::Request::new(body);
201
202 let path = if let Some(prefix) = strip_prefix {
204 parts
205 .uri
206 .path()
207 .strip_prefix(prefix)
208 .unwrap_or(parts.uri.path())
209 } else {
210 parts.uri.path()
211 };
212
213 let target_uri = {
215 let query_string = if let Some(custom_query) = query {
216 url::form_urlencoded::Serializer::new(String::new())
218 .extend_pairs(custom_query.iter())
219 .finish()
220 } else if let Some(orig_query) = parts.uri.query() {
221 orig_query.to_string()
223 } else {
224 String::new()
225 };
226
227 if query_string.is_empty() {
228 format!("{target_url}{path}")
229 } else {
230 format!("{target_url}{path}?{query_string}")
231 }
232 };
233
234 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
235 *proxy_req.method_mut() = parts.method.clone();
236
237 let mut header_map = parts.headers.clone();
239
240 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
242 header_map.insert(
243 hyper::header::CONTENT_LENGTH,
244 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
245 );
246 }
247
248 for (k, v) in headers {
250 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
251
252 match v {
253 crate::response::HeaderValue::Single(s) => {
254 let header_value = hyper::header::HeaderValue::from_str(s)?;
255 header_map.insert(header_name, header_value);
256 }
257 crate::response::HeaderValue::Multiple(values) => {
258 for value in values {
259 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
260 header_map.append(header_name.clone(), header_value);
261 }
262 }
263 }
264 }
265 }
266
267 if !preserve_host {
269 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
270 if let Some(authority) = target_uri.authority() {
271 header_map.insert(
272 hyper::header::HOST,
273 hyper::header::HeaderValue::from_str(authority.as_ref())?,
274 );
275 }
276 }
277 }
278
279 *proxy_req.headers_mut() = header_map;
280
281 let client =
283 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
284 .build_http();
285
286 match client.request(proxy_req).await {
287 Ok(response) => {
288 let (res_parts, body) = response.into_parts();
289 log_response(
290 request_id,
291 res_parts.status.as_u16(),
292 &res_parts.headers,
293 start_time,
294 );
295
296 let inner_body = body.map_err(|e| e.into()).boxed();
297 let logging_body = LoggingBody::new(inner_body, guard);
298 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
299 Ok(res)
300 }
301 Err(_e) => {
302 let empty_headers = hyper::header::HeaderMap::new();
303 log_response(request_id, 502, &empty_headers, start_time);
304
305 let inner_body = Full::new("Bad Gateway".into())
306 .map_err(|never| match never {})
307 .boxed();
308 let logging_body = LoggingBody::new(inner_body, guard);
309 let response = hyper::Response::builder()
310 .status(502)
311 .body(logging_body.boxed())?;
312 Ok(response)
313 }
314 }
315 }
316 }
317}
318
319async fn build_normal_response(
320 pipeline_result: PipelineResult,
321 use_brotli: bool,
322 guard: RequestGuard,
323 start_time: Instant,
324 reload_token: CancellationToken,
325) -> HTTPResult {
326 let request_id = guard.request_id();
327 let (inferred_content_type, http_meta, body) = pipeline_result;
328 let status = match (http_meta.status, &body) {
329 (Some(s), _) => s,
330 (None, ResponseTransport::Empty) => 204,
331 (None, _) => 200,
332 };
333 let mut builder = hyper::Response::builder().status(status);
334 let mut header_map = hyper::header::HeaderMap::new();
335
336 let content_type = http_meta
342 .headers
343 .get("content-type")
344 .or(http_meta.headers.get("Content-Type"))
345 .and_then(|hv| match hv {
346 crate::response::HeaderValue::Single(s) => Some(s.clone()),
347 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
348 })
349 .or(inferred_content_type)
350 .or_else(|| {
351 if matches!(body, ResponseTransport::Empty) {
352 None
353 } else {
354 Some("text/html; charset=utf-8".to_string())
355 }
356 });
357
358 if let Some(ref ct) = content_type {
359 header_map.insert(
360 hyper::header::CONTENT_TYPE,
361 hyper::header::HeaderValue::from_str(ct)?,
362 );
363 }
364
365 if use_brotli {
367 header_map.insert(
368 hyper::header::CONTENT_ENCODING,
369 hyper::header::HeaderValue::from_static("br"),
370 );
371 header_map.insert(
372 hyper::header::VARY,
373 hyper::header::HeaderValue::from_static("accept-encoding"),
374 );
375 }
376
377 let is_sse = content_type.as_deref() == Some("text/event-stream");
379 if is_sse {
380 header_map.insert(
381 hyper::header::CACHE_CONTROL,
382 hyper::header::HeaderValue::from_static("no-cache"),
383 );
384 header_map.insert(
385 hyper::header::CONNECTION,
386 hyper::header::HeaderValue::from_static("keep-alive"),
387 );
388 }
389
390 for (k, v) in &http_meta.headers {
391 if k.to_lowercase() != "content-type" {
392 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
393
394 match v {
395 crate::response::HeaderValue::Single(s) => {
396 let header_value = hyper::header::HeaderValue::from_str(s)?;
397 header_map.insert(header_name, header_value);
398 }
399 crate::response::HeaderValue::Multiple(values) => {
400 for value in values {
401 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
402 header_map.append(header_name.clone(), header_value);
403 }
404 }
405 }
406 }
407 }
408 }
409
410 log_response(request_id, status, &header_map, start_time);
411 *builder.headers_mut().unwrap() = header_map;
412
413 let inner_body = match body {
414 ResponseTransport::Empty => Empty::<Bytes>::new()
415 .map_err(|never| match never {})
416 .boxed(),
417 ResponseTransport::Full(bytes) => {
418 if use_brotli {
419 let compressed = compression::compress_full(&bytes)?;
420 Full::new(Bytes::from(compressed))
421 .map_err(|never| match never {})
422 .boxed()
423 } else {
424 Full::new(bytes.into())
425 .map_err(|never| match never {})
426 .boxed()
427 }
428 }
429 ResponseTransport::Stream(rx) => {
430 if use_brotli {
431 compression::compress_stream(rx)
432 } else if is_sse {
433 let stream = futures_util::stream::try_unfold(
435 (ReceiverStream::new(rx), reload_token),
436 |(mut data_rx, token)| async move {
437 tokio::select! {
438 biased;
439 _ = token.cancelled() => {
440 Err(std::io::Error::other("reload").into())
441 }
442 item = StreamExt::next(&mut data_rx) => {
443 match item {
444 Some(data) => Ok(Some((Frame::data(Bytes::from(data)), (data_rx, token)))),
445 None => Ok(None),
446 }
447 }
448 }
449 },
450 );
451 BodyExt::boxed(StreamBody::new(stream))
452 } else {
453 let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
454 BodyExt::boxed(StreamBody::new(stream))
455 }
456 }
457 };
458
459 let logging_body = LoggingBody::new(inner_body, guard);
461 Ok(builder.body(logging_body.boxed())?)
462}