1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use crate::compression;
16use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
17use crate::request::{resolve_trusted_ip, Request};
18use crate::response::{Response, ResponseBodyType, ResponseTransport};
19use crate::worker::{spawn_eval_thread, PipelineResult};
20
21type BoxError = Box<dyn std::error::Error + Send + Sync>;
22type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
23
24const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
25const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
26const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
27
28pub struct AppConfig {
29 pub trusted_proxies: Vec<ipnet::IpNet>,
30 pub datastar: bool,
31 pub dev: bool,
32}
33
34pub async fn handle<B>(
35 engine: Arc<ArcSwap<crate::Engine>>,
36 addr: Option<SocketAddr>,
37 config: Arc<AppConfig>,
38 req: hyper::Request<B>,
39) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
40where
41 B: hyper::body::Body + Unpin + Send + 'static,
42 B::Data: Into<Bytes> + Clone + Send,
43 B::Error: Into<BoxError> + Send,
44{
45 let engine = engine.load_full();
47 match handle_inner(engine, addr, config, req).await {
48 Ok(response) => Ok(response),
49 Err(err) => {
50 eprintln!("Error handling request: {err}");
51 let response = hyper::Response::builder().status(500).body(
52 Full::new(format!("Script error: {err}").into())
53 .map_err(|never| match never {})
54 .boxed(),
55 )?;
56 Ok(response)
57 }
58 }
59}
60
61async fn handle_inner<B>(
62 engine: Arc<crate::Engine>,
63 addr: Option<SocketAddr>,
64 config: Arc<AppConfig>,
65 req: hyper::Request<B>,
66) -> HTTPResult
67where
68 B: hyper::body::Body + Unpin + Send + 'static,
69 B::Data: Into<Bytes> + Clone + Send,
70 B::Error: Into<BoxError> + Send,
71{
72 let (parts, mut body) = req.into_parts();
73
74 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
76
77 tokio::task::spawn(async move {
79 while let Some(frame) = body.frame().await {
80 match frame {
81 Ok(frame) => {
82 if let Some(data) = frame.data_ref() {
83 let bytes: Bytes = (*data).clone().into();
84 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
85 break;
86 }
87 }
88 }
89 Err(err) => {
90 let _ = body_tx.send(Err(err.into())).await;
91 break;
92 }
93 }
94 }
95 });
96
97 let stream = nu_protocol::ByteStream::from_fn(
99 nu_protocol::Span::unknown(),
100 engine.state.signals().clone(),
101 nu_protocol::ByteStreamType::Unknown,
102 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
103 Some(Ok(bytes)) => {
104 buffer.extend_from_slice(&bytes);
105 Ok(true)
106 }
107 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
108 error: "Body read error".into(),
109 msg: err.to_string(),
110 span: None,
111 help: None,
112 inner: vec![],
113 }),
114 None => Ok(false),
115 },
116 );
117
118 let start_time = Instant::now();
120 let request_id = scru128::new();
121 let guard = RequestGuard::new(request_id);
122
123 let remote_ip = addr.as_ref().map(|a| a.ip());
124 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
125
126 let request = Request {
127 proto: format!("{:?}", parts.version),
128 method: parts.method.clone(),
129 authority: parts.uri.authority().map(|a| a.to_string()),
130 remote_ip,
131 remote_port: addr.as_ref().map(|a| a.port()),
132 trusted_ip,
133 headers: parts.headers.clone(),
134 uri: parts.uri.clone(),
135 path: parts.uri.path().to_string(),
136 query: parts
137 .uri
138 .query()
139 .map(|v| {
140 url::form_urlencoded::parse(v.as_bytes())
141 .into_owned()
142 .collect()
143 })
144 .unwrap_or_else(std::collections::HashMap::new),
145 };
146
147 log_request(request_id, &request);
149
150 if config.datastar && request.path == DATASTAR_JS_PATH {
152 let use_brotli = compression::accepts_brotli(&parts.headers);
153 let mut header_map = hyper::header::HeaderMap::new();
154 header_map.insert(
155 hyper::header::CONTENT_TYPE,
156 hyper::header::HeaderValue::from_static("application/javascript"),
157 );
158 header_map.insert(
159 hyper::header::CACHE_CONTROL,
160 hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
161 );
162 let body = if use_brotli {
163 header_map.insert(
164 hyper::header::CONTENT_ENCODING,
165 hyper::header::HeaderValue::from_static("br"),
166 );
167 header_map.insert(
168 hyper::header::VARY,
169 hyper::header::HeaderValue::from_static("accept-encoding"),
170 );
171 Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
172 .map_err(|never| match never {})
173 .boxed()
174 } else {
175 Full::new(Bytes::from_static(DATASTAR_JS))
176 .map_err(|never| match never {})
177 .boxed()
178 };
179 log_response(request_id, 200, &header_map, start_time);
180 let logging_body = LoggingBody::new(body, guard);
181 let mut response = hyper::Response::builder()
182 .status(200)
183 .body(logging_body.boxed())?;
184 *response.headers_mut() = header_map;
185 return Ok(response);
186 }
187
188 let sse_cancel_token = engine.sse_cancel_token.clone();
189 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
190
191 let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
195 tokio::join!(async { meta_rx.await.ok() }, async {
196 bridged_body.await.map_err(|e| e.into())
197 });
198
199 let use_brotli = compression::accepts_brotli(&parts.headers);
200
201 match special_response.as_ref().map(|r| &r.body_type) {
203 Some(ResponseBodyType::Normal) | None => {
204 build_normal_response(
206 body_result?,
207 use_brotli,
208 guard,
209 start_time,
210 sse_cancel_token,
211 )
212 .await
213 }
214 Some(ResponseBodyType::Static {
215 root,
216 path,
217 fallback,
218 }) => {
219 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
220 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
221 *static_req.method_mut() = parts.method.clone();
222 *static_req.headers_mut() = parts.headers.clone();
223
224 let res = if let Some(fallback) = fallback {
225 let fp = root.join(fallback);
226 ServeDir::new(root)
227 .fallback(ServeFile::new(fp))
228 .call(static_req)
229 .await?
230 } else {
231 ServeDir::new(root).call(static_req).await?
232 };
233 let (res_parts, body) = res.into_parts();
234 log_response(
235 request_id,
236 res_parts.status.as_u16(),
237 &res_parts.headers,
238 start_time,
239 );
240
241 let bytes = body.collect().await?.to_bytes();
242 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
243 let logging_body = LoggingBody::new(inner_body, guard);
244 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
245 Ok(res)
246 }
247 Some(ResponseBodyType::ReverseProxy {
248 target_url,
249 headers,
250 preserve_host,
251 strip_prefix,
252 request_body,
253 query,
254 }) => {
255 let body = Full::new(Bytes::from(request_body.clone()));
256 let mut proxy_req = hyper::Request::new(body);
257
258 let path = if let Some(prefix) = strip_prefix {
260 parts
261 .uri
262 .path()
263 .strip_prefix(prefix)
264 .unwrap_or(parts.uri.path())
265 } else {
266 parts.uri.path()
267 };
268
269 let target_uri = {
271 let query_string = if let Some(custom_query) = query {
272 url::form_urlencoded::Serializer::new(String::new())
274 .extend_pairs(custom_query.iter())
275 .finish()
276 } else if let Some(orig_query) = parts.uri.query() {
277 orig_query.to_string()
279 } else {
280 String::new()
281 };
282
283 if query_string.is_empty() {
284 format!("{target_url}{path}")
285 } else {
286 format!("{target_url}{path}?{query_string}")
287 }
288 };
289
290 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
291 *proxy_req.method_mut() = parts.method.clone();
292
293 let mut header_map = parts.headers.clone();
295
296 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
298 header_map.insert(
299 hyper::header::CONTENT_LENGTH,
300 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
301 );
302 }
303
304 for (k, v) in headers {
306 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
307
308 match v {
309 crate::response::HeaderValue::Single(s) => {
310 let header_value = hyper::header::HeaderValue::from_str(s)?;
311 header_map.insert(header_name, header_value);
312 }
313 crate::response::HeaderValue::Multiple(values) => {
314 for value in values {
315 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
316 header_map.append(header_name.clone(), header_value);
317 }
318 }
319 }
320 }
321 }
322
323 if !preserve_host {
325 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
326 if let Some(authority) = target_uri.authority() {
327 header_map.insert(
328 hyper::header::HOST,
329 hyper::header::HeaderValue::from_str(authority.as_ref())?,
330 );
331 }
332 }
333 }
334
335 *proxy_req.headers_mut() = header_map;
336
337 let client =
339 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
340 .build_http();
341
342 match client.request(proxy_req).await {
343 Ok(response) => {
344 let (res_parts, body) = response.into_parts();
345 log_response(
346 request_id,
347 res_parts.status.as_u16(),
348 &res_parts.headers,
349 start_time,
350 );
351
352 let inner_body = body.map_err(|e| e.into()).boxed();
353 let logging_body = LoggingBody::new(inner_body, guard);
354 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
355 Ok(res)
356 }
357 Err(_e) => {
358 let empty_headers = hyper::header::HeaderMap::new();
359 log_response(request_id, 502, &empty_headers, start_time);
360
361 let inner_body = Full::new("Bad Gateway".into())
362 .map_err(|never| match never {})
363 .boxed();
364 let logging_body = LoggingBody::new(inner_body, guard);
365 let response = hyper::Response::builder()
366 .status(502)
367 .body(logging_body.boxed())?;
368 Ok(response)
369 }
370 }
371 }
372 }
373}
374
375async fn build_normal_response(
376 pipeline_result: PipelineResult,
377 use_brotli: bool,
378 guard: RequestGuard,
379 start_time: Instant,
380 sse_cancel_token: CancellationToken,
381) -> HTTPResult {
382 let request_id = guard.request_id();
383 let (inferred_content_type, http_meta, body) = pipeline_result;
384 let status = match (http_meta.status, &body) {
385 (Some(s), _) => s,
386 (None, ResponseTransport::Empty) => 204,
387 (None, _) => 200,
388 };
389 let mut builder = hyper::Response::builder().status(status);
390 let mut header_map = hyper::header::HeaderMap::new();
391
392 let content_type = http_meta
398 .headers
399 .get("content-type")
400 .or(http_meta.headers.get("Content-Type"))
401 .and_then(|hv| match hv {
402 crate::response::HeaderValue::Single(s) => Some(s.clone()),
403 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
404 })
405 .or(inferred_content_type)
406 .or_else(|| {
407 if matches!(body, ResponseTransport::Empty) {
408 None
409 } else {
410 Some("text/html; charset=utf-8".to_string())
411 }
412 });
413
414 if let Some(ref ct) = content_type {
415 header_map.insert(
416 hyper::header::CONTENT_TYPE,
417 hyper::header::HeaderValue::from_str(ct)?,
418 );
419 }
420
421 if use_brotli {
423 header_map.insert(
424 hyper::header::CONTENT_ENCODING,
425 hyper::header::HeaderValue::from_static("br"),
426 );
427 header_map.insert(
428 hyper::header::VARY,
429 hyper::header::HeaderValue::from_static("accept-encoding"),
430 );
431 }
432
433 let is_sse = content_type.as_deref() == Some("text/event-stream");
435 if is_sse {
436 header_map.insert(
437 hyper::header::CACHE_CONTROL,
438 hyper::header::HeaderValue::from_static("no-cache"),
439 );
440 header_map.insert(
441 hyper::header::CONNECTION,
442 hyper::header::HeaderValue::from_static("keep-alive"),
443 );
444 }
445
446 for (k, v) in &http_meta.headers {
447 if k.to_lowercase() != "content-type" {
448 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
449
450 match v {
451 crate::response::HeaderValue::Single(s) => {
452 let header_value = hyper::header::HeaderValue::from_str(s)?;
453 header_map.insert(header_name, header_value);
454 }
455 crate::response::HeaderValue::Multiple(values) => {
456 for value in values {
457 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
458 header_map.append(header_name.clone(), header_value);
459 }
460 }
461 }
462 }
463 }
464 }
465
466 log_response(request_id, status, &header_map, start_time);
467 *builder.headers_mut().unwrap() = header_map;
468
469 let inner_body = match body {
470 ResponseTransport::Empty => Empty::<Bytes>::new()
471 .map_err(|never| match never {})
472 .boxed(),
473 ResponseTransport::Full(bytes) => {
474 if use_brotli {
475 let compressed = compression::compress_full(&bytes)?;
476 Full::new(Bytes::from(compressed))
477 .map_err(|never| match never {})
478 .boxed()
479 } else {
480 Full::new(bytes.into())
481 .map_err(|never| match never {})
482 .boxed()
483 }
484 }
485 ResponseTransport::Stream(rx) => {
486 let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
488 Box::pin(futures_util::stream::unfold(
490 (ReceiverStream::new(rx), sse_cancel_token),
491 |(mut data_rx, token)| async move {
492 tokio::select! {
493 biased;
494 _ = token.cancelled() => None,
495 item = StreamExt::next(&mut data_rx) => {
496 item.map(|data| (data, (data_rx, token)))
497 }
498 }
499 },
500 ))
501 } else {
502 Box::pin(ReceiverStream::new(rx))
503 };
504
505 if use_brotli {
507 let brotli = compression::BrotliStream::new(byte_stream);
508 BodyExt::boxed(StreamBody::new(brotli))
509 } else {
510 let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
511 BodyExt::boxed(StreamBody::new(stream))
512 }
513 }
514 };
515
516 let logging_body = LoggingBody::new(inner_body, guard);
518 Ok(builder.body(logging_body.boxed())?)
519}