1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use nu_protocol::shell_error::generic::GenericError;
16
17use crate::compression;
18use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
19use crate::request::{resolve_trusted_ip, Request};
20use crate::response::{Response, ResponseBodyType, ResponseTransport};
21use crate::worker::{spawn_eval_thread, PipelineResult};
22
/// Boxed, thread-safe error type used by every fallible step in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Result produced by the request handlers: a response with a boxed `Bytes` body, or a
/// `BoxError` (which `handle` converts into a 500 response).
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
25
// NOTE(review): the Datastar version appears in all three constants below; keep them in sync
// when upgrading the bundled script.

/// URL path at which the bundled Datastar script is served when `AppConfig::datastar` is true.
const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
/// The bundled Datastar script, embedded uncompressed; sent to clients that don't accept brotli.
const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
/// Pre-compressed `.br` copy of the same script; sent with `Content-Encoding: br`.
const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
29
/// Server-wide settings, shared via `Arc` with every request handler.
pub struct AppConfig {
    /// Proxy networks passed to `resolve_trusted_ip` when computing a request's `trusted_ip`
    /// (presumably forwarding headers are honored only from these peers — see that function).
    pub trusted_proxies: Vec<ipnet::IpNet>,
    /// When true, requests for `DATASTAR_JS_PATH` are answered with the bundled Datastar script.
    pub datastar: bool,
    /// Development-mode flag. Not read in this module; presumably consumed elsewhere — confirm.
    pub dev: bool,
}
35
36pub async fn handle<B>(
37 engine: Arc<ArcSwap<crate::Engine>>,
38 addr: Option<SocketAddr>,
39 config: Arc<AppConfig>,
40 req: hyper::Request<B>,
41) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
42where
43 B: hyper::body::Body + Unpin + Send + 'static,
44 B::Data: Into<Bytes> + Clone + Send,
45 B::Error: Into<BoxError> + Send,
46{
47 let engine = engine.load_full();
49 match handle_inner(engine, addr, config, req).await {
50 Ok(response) => Ok(response),
51 Err(err) => {
52 eprintln!("Error handling request: {err}");
53 let response = hyper::Response::builder().status(500).body(
54 Full::new(format!("Script error: {err}").into())
55 .map_err(|never| match never {})
56 .boxed(),
57 )?;
58 Ok(response)
59 }
60 }
61}
62
63async fn handle_inner<B>(
64 engine: Arc<crate::Engine>,
65 addr: Option<SocketAddr>,
66 config: Arc<AppConfig>,
67 req: hyper::Request<B>,
68) -> HTTPResult
69where
70 B: hyper::body::Body + Unpin + Send + 'static,
71 B::Data: Into<Bytes> + Clone + Send,
72 B::Error: Into<BoxError> + Send,
73{
74 let (parts, mut body) = req.into_parts();
75
76 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
78
79 tokio::task::spawn(async move {
81 while let Some(frame) = body.frame().await {
82 match frame {
83 Ok(frame) => {
84 if let Some(data) = frame.data_ref() {
85 let bytes: Bytes = (*data).clone().into();
86 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
87 break;
88 }
89 }
90 }
91 Err(err) => {
92 let _ = body_tx.send(Err(err.into())).await;
93 break;
94 }
95 }
96 }
97 });
98
99 let stream = nu_protocol::ByteStream::from_fn(
101 nu_protocol::Span::unknown(),
102 engine.state.signals().clone(),
103 nu_protocol::ByteStreamType::Unknown,
104 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
105 Some(Ok(bytes)) => {
106 buffer.extend_from_slice(&bytes);
107 Ok(true)
108 }
109 Some(Err(err)) => Err(nu_protocol::ShellError::Generic(
110 GenericError::new_internal("Body read error", err.to_string()),
111 )),
112 None => Ok(false),
113 },
114 );
115
116 let start_time = Instant::now();
118 let request_id = scru128::new();
119 let guard = RequestGuard::new(request_id);
120
121 let remote_ip = addr.as_ref().map(|a| a.ip());
122 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
123
124 let request = Request {
125 proto: format!("{:?}", parts.version),
126 method: parts.method.clone(),
127 authority: parts.uri.authority().map(|a| a.to_string()),
128 remote_ip,
129 remote_port: addr.as_ref().map(|a| a.port()),
130 trusted_ip,
131 headers: parts.headers.clone(),
132 uri: parts.uri.clone(),
133 path: parts.uri.path().to_string(),
134 query: parts
135 .uri
136 .query()
137 .map(|v| {
138 url::form_urlencoded::parse(v.as_bytes())
139 .into_owned()
140 .collect()
141 })
142 .unwrap_or_else(std::collections::HashMap::new),
143 };
144
145 log_request(request_id, &request);
147
148 if config.datastar && request.path == DATASTAR_JS_PATH {
150 let use_brotli = compression::accepts_brotli(&parts.headers);
151 let mut header_map = hyper::header::HeaderMap::new();
152 header_map.insert(
153 hyper::header::CONTENT_TYPE,
154 hyper::header::HeaderValue::from_static("application/javascript"),
155 );
156 header_map.insert(
157 hyper::header::CACHE_CONTROL,
158 hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
159 );
160 let body = if use_brotli {
161 header_map.insert(
162 hyper::header::CONTENT_ENCODING,
163 hyper::header::HeaderValue::from_static("br"),
164 );
165 header_map.insert(
166 hyper::header::VARY,
167 hyper::header::HeaderValue::from_static("accept-encoding"),
168 );
169 Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
170 .map_err(|never| match never {})
171 .boxed()
172 } else {
173 Full::new(Bytes::from_static(DATASTAR_JS))
174 .map_err(|never| match never {})
175 .boxed()
176 };
177 log_response(request_id, 200, &header_map, start_time);
178 let logging_body = LoggingBody::new(body, guard);
179 let mut response = hyper::Response::builder()
180 .status(200)
181 .body(logging_body.boxed())?;
182 *response.headers_mut() = header_map;
183 return Ok(response);
184 }
185
186 let sse_cancel_token = engine.sse_cancel_token.clone();
187 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
188
189 let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
193 tokio::join!(async { meta_rx.await.ok() }, async {
194 bridged_body.await.map_err(|e| e.into())
195 });
196
197 let use_brotli = compression::accepts_brotli(&parts.headers);
198
199 match special_response.as_ref().map(|r| &r.body_type) {
201 Some(ResponseBodyType::Normal) | None => {
202 build_normal_response(
204 body_result?,
205 use_brotli,
206 guard,
207 start_time,
208 sse_cancel_token,
209 )
210 .await
211 }
212 Some(ResponseBodyType::Static {
213 root,
214 path,
215 fallback,
216 }) => {
217 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
218 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
219 *static_req.method_mut() = parts.method.clone();
220 *static_req.headers_mut() = parts.headers.clone();
221
222 let res = if let Some(fallback) = fallback {
223 let fp = root.join(fallback);
224 ServeDir::new(root)
225 .fallback(ServeFile::new(fp))
226 .call(static_req)
227 .await?
228 } else {
229 ServeDir::new(root).call(static_req).await?
230 };
231 let (res_parts, body) = res.into_parts();
232 log_response(
233 request_id,
234 res_parts.status.as_u16(),
235 &res_parts.headers,
236 start_time,
237 );
238
239 let bytes = body.collect().await?.to_bytes();
240 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
241 let logging_body = LoggingBody::new(inner_body, guard);
242 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
243 Ok(res)
244 }
245 Some(ResponseBodyType::ReverseProxy {
246 target_url,
247 headers,
248 preserve_host,
249 strip_prefix,
250 request_body,
251 query,
252 }) => {
253 let body = Full::new(Bytes::from(request_body.clone()));
254 let mut proxy_req = hyper::Request::new(body);
255
256 let path = if let Some(prefix) = strip_prefix {
258 parts
259 .uri
260 .path()
261 .strip_prefix(prefix)
262 .unwrap_or(parts.uri.path())
263 } else {
264 parts.uri.path()
265 };
266
267 let path = if path.starts_with('/') {
271 path.to_string()
272 } else {
273 format!("/{path}")
274 };
275
276 let target_uri = {
278 let query_string = if let Some(custom_query) = query {
279 url::form_urlencoded::Serializer::new(String::new())
281 .extend_pairs(custom_query.iter())
282 .finish()
283 } else if let Some(orig_query) = parts.uri.query() {
284 orig_query.to_string()
286 } else {
287 String::new()
288 };
289
290 if query_string.is_empty() {
291 format!("{target_url}{path}")
292 } else {
293 format!("{target_url}{path}?{query_string}")
294 }
295 };
296
297 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
298 *proxy_req.method_mut() = parts.method.clone();
299
300 let mut header_map = parts.headers.clone();
302
303 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
305 header_map.insert(
306 hyper::header::CONTENT_LENGTH,
307 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
308 );
309 }
310
311 for (k, v) in headers {
313 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
314
315 match v {
316 crate::response::HeaderValue::Single(s) => {
317 let header_value = hyper::header::HeaderValue::from_str(s)?;
318 header_map.insert(header_name, header_value);
319 }
320 crate::response::HeaderValue::Multiple(values) => {
321 for value in values {
322 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
323 header_map.append(header_name.clone(), header_value);
324 }
325 }
326 }
327 }
328 }
329
330 if !preserve_host {
332 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
333 if let Some(authority) = target_uri.authority() {
334 header_map.insert(
335 hyper::header::HOST,
336 hyper::header::HeaderValue::from_str(authority.as_ref())?,
337 );
338 }
339 }
340 }
341
342 *proxy_req.headers_mut() = header_map;
343
344 let client =
346 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
347 .build_http();
348
349 match client.request(proxy_req).await {
350 Ok(response) => {
351 let (res_parts, body) = response.into_parts();
352 log_response(
353 request_id,
354 res_parts.status.as_u16(),
355 &res_parts.headers,
356 start_time,
357 );
358
359 let inner_body = body.map_err(|e| e.into()).boxed();
360 let logging_body = LoggingBody::new(inner_body, guard);
361 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
362 Ok(res)
363 }
364 Err(_e) => {
365 let empty_headers = hyper::header::HeaderMap::new();
366 log_response(request_id, 502, &empty_headers, start_time);
367
368 let inner_body = Full::new("Bad Gateway".into())
369 .map_err(|never| match never {})
370 .boxed();
371 let logging_body = LoggingBody::new(inner_body, guard);
372 let response = hyper::Response::builder()
373 .status(502)
374 .body(logging_body.boxed())?;
375 Ok(response)
376 }
377 }
378 }
379 }
380}
381
382async fn build_normal_response(
383 pipeline_result: PipelineResult,
384 use_brotli: bool,
385 guard: RequestGuard,
386 start_time: Instant,
387 sse_cancel_token: CancellationToken,
388) -> HTTPResult {
389 let request_id = guard.request_id();
390 let (inferred_content_type, http_meta, body) = pipeline_result;
391 let status = match (http_meta.status, &body) {
392 (Some(s), _) => s,
393 (None, ResponseTransport::Empty) => 204,
394 (None, _) => 200,
395 };
396 let mut builder = hyper::Response::builder().status(status);
397 let mut header_map = hyper::header::HeaderMap::new();
398
399 let content_type = http_meta
405 .headers
406 .get("content-type")
407 .or(http_meta.headers.get("Content-Type"))
408 .and_then(|hv| match hv {
409 crate::response::HeaderValue::Single(s) => Some(s.clone()),
410 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
411 })
412 .or(inferred_content_type)
413 .or_else(|| {
414 if matches!(body, ResponseTransport::Empty) {
415 None
416 } else {
417 Some("text/html; charset=utf-8".to_string())
418 }
419 });
420
421 if let Some(ref ct) = content_type {
422 header_map.insert(
423 hyper::header::CONTENT_TYPE,
424 hyper::header::HeaderValue::from_str(ct)?,
425 );
426 }
427
428 if use_brotli {
430 header_map.insert(
431 hyper::header::CONTENT_ENCODING,
432 hyper::header::HeaderValue::from_static("br"),
433 );
434 header_map.insert(
435 hyper::header::VARY,
436 hyper::header::HeaderValue::from_static("accept-encoding"),
437 );
438 }
439
440 let is_sse = content_type.as_deref() == Some("text/event-stream");
442 if is_sse {
443 header_map.insert(
444 hyper::header::CACHE_CONTROL,
445 hyper::header::HeaderValue::from_static("no-cache"),
446 );
447 header_map.insert(
448 hyper::header::CONNECTION,
449 hyper::header::HeaderValue::from_static("keep-alive"),
450 );
451 }
452
453 for (k, v) in &http_meta.headers {
454 if k.to_lowercase() != "content-type" {
455 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
456
457 match v {
458 crate::response::HeaderValue::Single(s) => {
459 let header_value = hyper::header::HeaderValue::from_str(s)?;
460 header_map.insert(header_name, header_value);
461 }
462 crate::response::HeaderValue::Multiple(values) => {
463 for value in values {
464 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
465 header_map.append(header_name.clone(), header_value);
466 }
467 }
468 }
469 }
470 }
471 }
472
473 log_response(request_id, status, &header_map, start_time);
474 *builder.headers_mut().unwrap() = header_map;
475
476 let inner_body = match body {
477 ResponseTransport::Empty => Empty::<Bytes>::new()
478 .map_err(|never| match never {})
479 .boxed(),
480 ResponseTransport::Full(bytes) => {
481 if use_brotli {
482 let compressed = compression::compress_full(&bytes)?;
483 Full::new(Bytes::from(compressed))
484 .map_err(|never| match never {})
485 .boxed()
486 } else {
487 Full::new(bytes.into())
488 .map_err(|never| match never {})
489 .boxed()
490 }
491 }
492 ResponseTransport::Stream(rx) => {
493 let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
495 Box::pin(futures_util::stream::unfold(
497 (ReceiverStream::new(rx), sse_cancel_token),
498 |(mut data_rx, token)| async move {
499 tokio::select! {
500 biased;
501 _ = token.cancelled() => None,
502 item = StreamExt::next(&mut data_rx) => {
503 item.map(|data| (data, (data_rx, token)))
504 }
505 }
506 },
507 ))
508 } else {
509 Box::pin(ReceiverStream::new(rx))
510 };
511
512 if use_brotli {
514 let brotli = compression::BrotliStream::new(byte_stream);
515 BodyExt::boxed(StreamBody::new(brotli))
516 } else {
517 let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
518 BodyExt::boxed(StreamBody::new(stream))
519 }
520 }
521 };
522
523 let logging_body = LoggingBody::new(inner_body, guard);
525 Ok(builder.body(logging_body.boxed())?)
526}