1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use crate::compression;
16use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
17use crate::request::{resolve_trusted_ip, Request};
18use crate::response::{Response, ResponseBodyType, ResponseTransport};
19use crate::worker::{spawn_eval_thread, PipelineResult};
20
/// Type-erased, thread-safe error used for every fallible step in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Result of producing an HTTP response: a hyper response with a boxed body on
/// success, or a [`BoxError`] on failure.
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
23
/// Request path at which the embedded Datastar bundle is served (only when
/// `AppConfig::datastar` is enabled).
const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
/// Datastar bundle embedded at compile time; sent when the client does not
/// accept brotli.
const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
/// `.br` variant of the bundle embedded at compile time; sent with
/// `Content-Encoding: br` when the client accepts brotli.
const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
27
/// Server-wide settings shared (behind an `Arc`) by every request handler.
pub struct AppConfig {
    /// Networks passed to `resolve_trusted_ip` together with the request
    /// headers and the peer address to compute the request's `trusted_ip`.
    pub trusted_proxies: Vec<ipnet::IpNet>,
    /// When `true`, requests for `DATASTAR_JS_PATH` are answered directly with
    /// the embedded Datastar bundle instead of being passed to the engine.
    pub datastar: bool,
    /// NOTE(review): not read in this part of the file; presumably toggles
    /// development-mode behavior elsewhere — confirm against other users.
    pub dev: bool,
}
33
34pub async fn handle<B>(
35 engine: Arc<ArcSwap<crate::Engine>>,
36 addr: Option<SocketAddr>,
37 config: Arc<AppConfig>,
38 req: hyper::Request<B>,
39) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
40where
41 B: hyper::body::Body + Unpin + Send + 'static,
42 B::Data: Into<Bytes> + Clone + Send,
43 B::Error: Into<BoxError> + Send,
44{
45 let engine = engine.load_full();
47 match handle_inner(engine, addr, config, req).await {
48 Ok(response) => Ok(response),
49 Err(err) => {
50 eprintln!("Error handling request: {err}");
51 let response = hyper::Response::builder().status(500).body(
52 Full::new(format!("Script error: {err}").into())
53 .map_err(|never| match never {})
54 .boxed(),
55 )?;
56 Ok(response)
57 }
58 }
59}
60
61async fn handle_inner<B>(
62 engine: Arc<crate::Engine>,
63 addr: Option<SocketAddr>,
64 config: Arc<AppConfig>,
65 req: hyper::Request<B>,
66) -> HTTPResult
67where
68 B: hyper::body::Body + Unpin + Send + 'static,
69 B::Data: Into<Bytes> + Clone + Send,
70 B::Error: Into<BoxError> + Send,
71{
72 let (parts, mut body) = req.into_parts();
73
74 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
76
77 tokio::task::spawn(async move {
79 while let Some(frame) = body.frame().await {
80 match frame {
81 Ok(frame) => {
82 if let Some(data) = frame.data_ref() {
83 let bytes: Bytes = (*data).clone().into();
84 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
85 break;
86 }
87 }
88 }
89 Err(err) => {
90 let _ = body_tx.send(Err(err.into())).await;
91 break;
92 }
93 }
94 }
95 });
96
97 let stream = nu_protocol::ByteStream::from_fn(
99 nu_protocol::Span::unknown(),
100 engine.state.signals().clone(),
101 nu_protocol::ByteStreamType::Unknown,
102 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
103 Some(Ok(bytes)) => {
104 buffer.extend_from_slice(&bytes);
105 Ok(true)
106 }
107 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
108 error: "Body read error".into(),
109 msg: err.to_string(),
110 span: None,
111 help: None,
112 inner: vec![],
113 }),
114 None => Ok(false),
115 },
116 );
117
118 let start_time = Instant::now();
120 let request_id = scru128::new();
121 let guard = RequestGuard::new(request_id);
122
123 let remote_ip = addr.as_ref().map(|a| a.ip());
124 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
125
126 let request = Request {
127 proto: format!("{:?}", parts.version),
128 method: parts.method.clone(),
129 authority: parts.uri.authority().map(|a| a.to_string()),
130 remote_ip,
131 remote_port: addr.as_ref().map(|a| a.port()),
132 trusted_ip,
133 headers: parts.headers.clone(),
134 uri: parts.uri.clone(),
135 path: parts.uri.path().to_string(),
136 query: parts
137 .uri
138 .query()
139 .map(|v| {
140 url::form_urlencoded::parse(v.as_bytes())
141 .into_owned()
142 .collect()
143 })
144 .unwrap_or_else(std::collections::HashMap::new),
145 };
146
147 log_request(request_id, &request);
149
150 if config.datastar && request.path == DATASTAR_JS_PATH {
152 let use_brotli = compression::accepts_brotli(&parts.headers);
153 let mut header_map = hyper::header::HeaderMap::new();
154 header_map.insert(
155 hyper::header::CONTENT_TYPE,
156 hyper::header::HeaderValue::from_static("application/javascript"),
157 );
158 header_map.insert(
159 hyper::header::CACHE_CONTROL,
160 hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
161 );
162 let body = if use_brotli {
163 header_map.insert(
164 hyper::header::CONTENT_ENCODING,
165 hyper::header::HeaderValue::from_static("br"),
166 );
167 header_map.insert(
168 hyper::header::VARY,
169 hyper::header::HeaderValue::from_static("accept-encoding"),
170 );
171 Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
172 .map_err(|never| match never {})
173 .boxed()
174 } else {
175 Full::new(Bytes::from_static(DATASTAR_JS))
176 .map_err(|never| match never {})
177 .boxed()
178 };
179 log_response(request_id, 200, &header_map, start_time);
180 let logging_body = LoggingBody::new(body, guard);
181 let mut response = hyper::Response::builder()
182 .status(200)
183 .body(logging_body.boxed())?;
184 *response.headers_mut() = header_map;
185 return Ok(response);
186 }
187
188 let reload_token = engine.reload_token.clone();
189 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
190
191 let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
195 tokio::join!(async { meta_rx.await.ok() }, async {
196 bridged_body.await.map_err(|e| e.into())
197 });
198
199 let use_brotli = compression::accepts_brotli(&parts.headers);
200
201 match special_response.as_ref().map(|r| &r.body_type) {
203 Some(ResponseBodyType::Normal) | None => {
204 build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
206 }
207 Some(ResponseBodyType::Static {
208 root,
209 path,
210 fallback,
211 }) => {
212 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
213 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
214 *static_req.method_mut() = parts.method.clone();
215 *static_req.headers_mut() = parts.headers.clone();
216
217 let res = if let Some(fallback) = fallback {
218 let fp = root.join(fallback);
219 ServeDir::new(root)
220 .fallback(ServeFile::new(fp))
221 .call(static_req)
222 .await?
223 } else {
224 ServeDir::new(root).call(static_req).await?
225 };
226 let (res_parts, body) = res.into_parts();
227 log_response(
228 request_id,
229 res_parts.status.as_u16(),
230 &res_parts.headers,
231 start_time,
232 );
233
234 let bytes = body.collect().await?.to_bytes();
235 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
236 let logging_body = LoggingBody::new(inner_body, guard);
237 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
238 Ok(res)
239 }
240 Some(ResponseBodyType::ReverseProxy {
241 target_url,
242 headers,
243 preserve_host,
244 strip_prefix,
245 request_body,
246 query,
247 }) => {
248 let body = Full::new(Bytes::from(request_body.clone()));
249 let mut proxy_req = hyper::Request::new(body);
250
251 let path = if let Some(prefix) = strip_prefix {
253 parts
254 .uri
255 .path()
256 .strip_prefix(prefix)
257 .unwrap_or(parts.uri.path())
258 } else {
259 parts.uri.path()
260 };
261
262 let target_uri = {
264 let query_string = if let Some(custom_query) = query {
265 url::form_urlencoded::Serializer::new(String::new())
267 .extend_pairs(custom_query.iter())
268 .finish()
269 } else if let Some(orig_query) = parts.uri.query() {
270 orig_query.to_string()
272 } else {
273 String::new()
274 };
275
276 if query_string.is_empty() {
277 format!("{target_url}{path}")
278 } else {
279 format!("{target_url}{path}?{query_string}")
280 }
281 };
282
283 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
284 *proxy_req.method_mut() = parts.method.clone();
285
286 let mut header_map = parts.headers.clone();
288
289 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
291 header_map.insert(
292 hyper::header::CONTENT_LENGTH,
293 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
294 );
295 }
296
297 for (k, v) in headers {
299 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
300
301 match v {
302 crate::response::HeaderValue::Single(s) => {
303 let header_value = hyper::header::HeaderValue::from_str(s)?;
304 header_map.insert(header_name, header_value);
305 }
306 crate::response::HeaderValue::Multiple(values) => {
307 for value in values {
308 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
309 header_map.append(header_name.clone(), header_value);
310 }
311 }
312 }
313 }
314 }
315
316 if !preserve_host {
318 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
319 if let Some(authority) = target_uri.authority() {
320 header_map.insert(
321 hyper::header::HOST,
322 hyper::header::HeaderValue::from_str(authority.as_ref())?,
323 );
324 }
325 }
326 }
327
328 *proxy_req.headers_mut() = header_map;
329
330 let client =
332 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
333 .build_http();
334
335 match client.request(proxy_req).await {
336 Ok(response) => {
337 let (res_parts, body) = response.into_parts();
338 log_response(
339 request_id,
340 res_parts.status.as_u16(),
341 &res_parts.headers,
342 start_time,
343 );
344
345 let inner_body = body.map_err(|e| e.into()).boxed();
346 let logging_body = LoggingBody::new(inner_body, guard);
347 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
348 Ok(res)
349 }
350 Err(_e) => {
351 let empty_headers = hyper::header::HeaderMap::new();
352 log_response(request_id, 502, &empty_headers, start_time);
353
354 let inner_body = Full::new("Bad Gateway".into())
355 .map_err(|never| match never {})
356 .boxed();
357 let logging_body = LoggingBody::new(inner_body, guard);
358 let response = hyper::Response::builder()
359 .status(502)
360 .body(logging_body.boxed())?;
361 Ok(response)
362 }
363 }
364 }
365 }
366}
367
368async fn build_normal_response(
369 pipeline_result: PipelineResult,
370 use_brotli: bool,
371 guard: RequestGuard,
372 start_time: Instant,
373 reload_token: CancellationToken,
374) -> HTTPResult {
375 let request_id = guard.request_id();
376 let (inferred_content_type, http_meta, body) = pipeline_result;
377 let status = match (http_meta.status, &body) {
378 (Some(s), _) => s,
379 (None, ResponseTransport::Empty) => 204,
380 (None, _) => 200,
381 };
382 let mut builder = hyper::Response::builder().status(status);
383 let mut header_map = hyper::header::HeaderMap::new();
384
385 let content_type = http_meta
391 .headers
392 .get("content-type")
393 .or(http_meta.headers.get("Content-Type"))
394 .and_then(|hv| match hv {
395 crate::response::HeaderValue::Single(s) => Some(s.clone()),
396 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
397 })
398 .or(inferred_content_type)
399 .or_else(|| {
400 if matches!(body, ResponseTransport::Empty) {
401 None
402 } else {
403 Some("text/html; charset=utf-8".to_string())
404 }
405 });
406
407 if let Some(ref ct) = content_type {
408 header_map.insert(
409 hyper::header::CONTENT_TYPE,
410 hyper::header::HeaderValue::from_str(ct)?,
411 );
412 }
413
414 if use_brotli {
416 header_map.insert(
417 hyper::header::CONTENT_ENCODING,
418 hyper::header::HeaderValue::from_static("br"),
419 );
420 header_map.insert(
421 hyper::header::VARY,
422 hyper::header::HeaderValue::from_static("accept-encoding"),
423 );
424 }
425
426 let is_sse = content_type.as_deref() == Some("text/event-stream");
428 if is_sse {
429 header_map.insert(
430 hyper::header::CACHE_CONTROL,
431 hyper::header::HeaderValue::from_static("no-cache"),
432 );
433 header_map.insert(
434 hyper::header::CONNECTION,
435 hyper::header::HeaderValue::from_static("keep-alive"),
436 );
437 }
438
439 for (k, v) in &http_meta.headers {
440 if k.to_lowercase() != "content-type" {
441 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
442
443 match v {
444 crate::response::HeaderValue::Single(s) => {
445 let header_value = hyper::header::HeaderValue::from_str(s)?;
446 header_map.insert(header_name, header_value);
447 }
448 crate::response::HeaderValue::Multiple(values) => {
449 for value in values {
450 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
451 header_map.append(header_name.clone(), header_value);
452 }
453 }
454 }
455 }
456 }
457 }
458
459 log_response(request_id, status, &header_map, start_time);
460 *builder.headers_mut().unwrap() = header_map;
461
462 let inner_body = match body {
463 ResponseTransport::Empty => Empty::<Bytes>::new()
464 .map_err(|never| match never {})
465 .boxed(),
466 ResponseTransport::Full(bytes) => {
467 if use_brotli {
468 let compressed = compression::compress_full(&bytes)?;
469 Full::new(Bytes::from(compressed))
470 .map_err(|never| match never {})
471 .boxed()
472 } else {
473 Full::new(bytes.into())
474 .map_err(|never| match never {})
475 .boxed()
476 }
477 }
478 ResponseTransport::Stream(rx) => {
479 let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
481 Box::pin(futures_util::stream::unfold(
483 (ReceiverStream::new(rx), reload_token),
484 |(mut data_rx, token)| async move {
485 tokio::select! {
486 biased;
487 _ = token.cancelled() => None,
488 item = StreamExt::next(&mut data_rx) => {
489 item.map(|data| (data, (data_rx, token)))
490 }
491 }
492 },
493 ))
494 } else {
495 Box::pin(ReceiverStream::new(rx))
496 };
497
498 if use_brotli {
500 let brotli = compression::BrotliStream::new(byte_stream);
501 BodyExt::boxed(StreamBody::new(brotli))
502 } else {
503 let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
504 BodyExt::boxed(StreamBody::new(stream))
505 }
506 }
507 };
508
509 let logging_body = LoggingBody::new(inner_body, guard);
511 Ok(builder.body(logging_body.boxed())?)
512}