1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use futures_util::StreamExt;
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
8use hyper::body::{Bytes, Frame};
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12use tower_http::services::{ServeDir, ServeFile};
13
14use crate::compression;
15use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
16use crate::request::{resolve_trusted_ip, Request};
17use crate::response::{Response, ResponseBodyType, ResponseTransport};
18use crate::worker::{spawn_eval_thread, PipelineResult};
19
/// Boxed, thread-safe error type used throughout the request handlers.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Result of a handler: a response with a boxed body, or an error. Errors coming out of
/// `handle_inner` are turned into a `500` response by `handle`.
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;

// Bundled Datastar client. The version string appears in the URL path and in both
// `include_bytes!` file names; keep all three in sync when upgrading. Because the path
// is versioned, the response is served with an `immutable` cache policy.
/// URL path at which the embedded Datastar script is served.
const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.7.js";
/// Uncompressed Datastar script, embedded at compile time.
const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.7.js");
/// Brotli-compressed copy of the script, sent when the client accepts `br`.
const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.7.js.br");
26
/// Server-wide settings shared (via `Arc`) by every request handler.
pub struct AppConfig {
    /// Proxy networks passed to `resolve_trusted_ip`, together with the request headers
    /// and the peer IP, to compute `Request::trusted_ip`.
    pub trusted_proxies: Vec<ipnet::IpNet>,
    /// When `true`, requests for `DATASTAR_JS_PATH` are answered with the bundled
    /// Datastar script instead of being passed to the user's closure.
    pub datastar: bool,
    /// Development-mode flag. Not read by `handle` / `handle_inner`; see its callers.
    pub dev: bool,
}
32
33pub async fn handle<B>(
34 engine: Arc<ArcSwap<crate::Engine>>,
35 addr: Option<SocketAddr>,
36 config: Arc<AppConfig>,
37 req: hyper::Request<B>,
38) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
39where
40 B: hyper::body::Body + Unpin + Send + 'static,
41 B::Data: Into<Bytes> + Clone + Send,
42 B::Error: Into<BoxError> + Send,
43{
44 let engine = engine.load_full();
46 match handle_inner(engine, addr, config, req).await {
47 Ok(response) => Ok(response),
48 Err(err) => {
49 eprintln!("Error handling request: {err}");
50 let response = hyper::Response::builder().status(500).body(
51 Full::new(format!("Script error: {err}").into())
52 .map_err(|never| match never {})
53 .boxed(),
54 )?;
55 Ok(response)
56 }
57 }
58}
59
60async fn handle_inner<B>(
61 engine: Arc<crate::Engine>,
62 addr: Option<SocketAddr>,
63 config: Arc<AppConfig>,
64 req: hyper::Request<B>,
65) -> HTTPResult
66where
67 B: hyper::body::Body + Unpin + Send + 'static,
68 B::Data: Into<Bytes> + Clone + Send,
69 B::Error: Into<BoxError> + Send,
70{
71 let (parts, mut body) = req.into_parts();
72
73 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
75
76 tokio::task::spawn(async move {
78 while let Some(frame) = body.frame().await {
79 match frame {
80 Ok(frame) => {
81 if let Some(data) = frame.data_ref() {
82 let bytes: Bytes = (*data).clone().into();
83 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
84 break;
85 }
86 }
87 }
88 Err(err) => {
89 let _ = body_tx.send(Err(err.into())).await;
90 break;
91 }
92 }
93 }
94 });
95
96 let stream = nu_protocol::ByteStream::from_fn(
98 nu_protocol::Span::unknown(),
99 engine.state.signals().clone(),
100 nu_protocol::ByteStreamType::Unknown,
101 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
102 Some(Ok(bytes)) => {
103 buffer.extend_from_slice(&bytes);
104 Ok(true)
105 }
106 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
107 error: "Body read error".into(),
108 msg: err.to_string(),
109 span: None,
110 help: None,
111 inner: vec![],
112 }),
113 None => Ok(false),
114 },
115 );
116
117 let start_time = Instant::now();
119 let request_id = scru128::new();
120 let guard = RequestGuard::new(request_id);
121
122 let remote_ip = addr.as_ref().map(|a| a.ip());
123 let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
124
125 let request = Request {
126 proto: format!("{:?}", parts.version),
127 method: parts.method.clone(),
128 authority: parts.uri.authority().map(|a| a.to_string()),
129 remote_ip,
130 remote_port: addr.as_ref().map(|a| a.port()),
131 trusted_ip,
132 headers: parts.headers.clone(),
133 uri: parts.uri.clone(),
134 path: parts.uri.path().to_string(),
135 query: parts
136 .uri
137 .query()
138 .map(|v| {
139 url::form_urlencoded::parse(v.as_bytes())
140 .into_owned()
141 .collect()
142 })
143 .unwrap_or_else(std::collections::HashMap::new),
144 };
145
146 log_request(request_id, &request);
148
149 if config.datastar && request.path == DATASTAR_JS_PATH {
151 let use_brotli = compression::accepts_brotli(&parts.headers);
152 let mut header_map = hyper::header::HeaderMap::new();
153 header_map.insert(
154 hyper::header::CONTENT_TYPE,
155 hyper::header::HeaderValue::from_static("application/javascript"),
156 );
157 header_map.insert(
158 hyper::header::CACHE_CONTROL,
159 hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
160 );
161 let body = if use_brotli {
162 header_map.insert(
163 hyper::header::CONTENT_ENCODING,
164 hyper::header::HeaderValue::from_static("br"),
165 );
166 header_map.insert(
167 hyper::header::VARY,
168 hyper::header::HeaderValue::from_static("accept-encoding"),
169 );
170 Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
171 .map_err(|never| match never {})
172 .boxed()
173 } else {
174 Full::new(Bytes::from_static(DATASTAR_JS))
175 .map_err(|never| match never {})
176 .boxed()
177 };
178 log_response(request_id, 200, &header_map, start_time);
179 let logging_body = LoggingBody::new(body, guard);
180 let mut response = hyper::Response::builder()
181 .status(200)
182 .body(logging_body.boxed())?;
183 *response.headers_mut() = header_map;
184 return Ok(response);
185 }
186
187 let reload_token = engine.reload_token.clone();
188 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
189
190 let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
194 tokio::join!(async { meta_rx.await.ok() }, async {
195 bridged_body.await.map_err(|e| e.into())
196 });
197
198 let use_brotli = compression::accepts_brotli(&parts.headers);
199
200 match special_response.as_ref().map(|r| &r.body_type) {
202 Some(ResponseBodyType::Normal) | None => {
203 build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
205 }
206 Some(ResponseBodyType::Static {
207 root,
208 path,
209 fallback,
210 }) => {
211 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
212 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
213 *static_req.method_mut() = parts.method.clone();
214 *static_req.headers_mut() = parts.headers.clone();
215
216 let res = if let Some(fallback) = fallback {
217 let fp = root.join(fallback);
218 ServeDir::new(root)
219 .fallback(ServeFile::new(fp))
220 .call(static_req)
221 .await?
222 } else {
223 ServeDir::new(root).call(static_req).await?
224 };
225 let (res_parts, body) = res.into_parts();
226 log_response(
227 request_id,
228 res_parts.status.as_u16(),
229 &res_parts.headers,
230 start_time,
231 );
232
233 let bytes = body.collect().await?.to_bytes();
234 let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
235 let logging_body = LoggingBody::new(inner_body, guard);
236 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
237 Ok(res)
238 }
239 Some(ResponseBodyType::ReverseProxy {
240 target_url,
241 headers,
242 preserve_host,
243 strip_prefix,
244 request_body,
245 query,
246 }) => {
247 let body = Full::new(Bytes::from(request_body.clone()));
248 let mut proxy_req = hyper::Request::new(body);
249
250 let path = if let Some(prefix) = strip_prefix {
252 parts
253 .uri
254 .path()
255 .strip_prefix(prefix)
256 .unwrap_or(parts.uri.path())
257 } else {
258 parts.uri.path()
259 };
260
261 let target_uri = {
263 let query_string = if let Some(custom_query) = query {
264 url::form_urlencoded::Serializer::new(String::new())
266 .extend_pairs(custom_query.iter())
267 .finish()
268 } else if let Some(orig_query) = parts.uri.query() {
269 orig_query.to_string()
271 } else {
272 String::new()
273 };
274
275 if query_string.is_empty() {
276 format!("{target_url}{path}")
277 } else {
278 format!("{target_url}{path}?{query_string}")
279 }
280 };
281
282 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
283 *proxy_req.method_mut() = parts.method.clone();
284
285 let mut header_map = parts.headers.clone();
287
288 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
290 header_map.insert(
291 hyper::header::CONTENT_LENGTH,
292 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
293 );
294 }
295
296 for (k, v) in headers {
298 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
299
300 match v {
301 crate::response::HeaderValue::Single(s) => {
302 let header_value = hyper::header::HeaderValue::from_str(s)?;
303 header_map.insert(header_name, header_value);
304 }
305 crate::response::HeaderValue::Multiple(values) => {
306 for value in values {
307 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
308 header_map.append(header_name.clone(), header_value);
309 }
310 }
311 }
312 }
313 }
314
315 if !preserve_host {
317 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
318 if let Some(authority) = target_uri.authority() {
319 header_map.insert(
320 hyper::header::HOST,
321 hyper::header::HeaderValue::from_str(authority.as_ref())?,
322 );
323 }
324 }
325 }
326
327 *proxy_req.headers_mut() = header_map;
328
329 let client =
331 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
332 .build_http();
333
334 match client.request(proxy_req).await {
335 Ok(response) => {
336 let (res_parts, body) = response.into_parts();
337 log_response(
338 request_id,
339 res_parts.status.as_u16(),
340 &res_parts.headers,
341 start_time,
342 );
343
344 let inner_body = body.map_err(|e| e.into()).boxed();
345 let logging_body = LoggingBody::new(inner_body, guard);
346 let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
347 Ok(res)
348 }
349 Err(_e) => {
350 let empty_headers = hyper::header::HeaderMap::new();
351 log_response(request_id, 502, &empty_headers, start_time);
352
353 let inner_body = Full::new("Bad Gateway".into())
354 .map_err(|never| match never {})
355 .boxed();
356 let logging_body = LoggingBody::new(inner_body, guard);
357 let response = hyper::Response::builder()
358 .status(502)
359 .body(logging_body.boxed())?;
360 Ok(response)
361 }
362 }
363 }
364 }
365}
366
367async fn build_normal_response(
368 pipeline_result: PipelineResult,
369 use_brotli: bool,
370 guard: RequestGuard,
371 start_time: Instant,
372 reload_token: CancellationToken,
373) -> HTTPResult {
374 let request_id = guard.request_id();
375 let (inferred_content_type, http_meta, body) = pipeline_result;
376 let status = match (http_meta.status, &body) {
377 (Some(s), _) => s,
378 (None, ResponseTransport::Empty) => 204,
379 (None, _) => 200,
380 };
381 let mut builder = hyper::Response::builder().status(status);
382 let mut header_map = hyper::header::HeaderMap::new();
383
384 let content_type = http_meta
390 .headers
391 .get("content-type")
392 .or(http_meta.headers.get("Content-Type"))
393 .and_then(|hv| match hv {
394 crate::response::HeaderValue::Single(s) => Some(s.clone()),
395 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
396 })
397 .or(inferred_content_type)
398 .or_else(|| {
399 if matches!(body, ResponseTransport::Empty) {
400 None
401 } else {
402 Some("text/html; charset=utf-8".to_string())
403 }
404 });
405
406 if let Some(ref ct) = content_type {
407 header_map.insert(
408 hyper::header::CONTENT_TYPE,
409 hyper::header::HeaderValue::from_str(ct)?,
410 );
411 }
412
413 if use_brotli {
415 header_map.insert(
416 hyper::header::CONTENT_ENCODING,
417 hyper::header::HeaderValue::from_static("br"),
418 );
419 header_map.insert(
420 hyper::header::VARY,
421 hyper::header::HeaderValue::from_static("accept-encoding"),
422 );
423 }
424
425 let is_sse = content_type.as_deref() == Some("text/event-stream");
427 if is_sse {
428 header_map.insert(
429 hyper::header::CACHE_CONTROL,
430 hyper::header::HeaderValue::from_static("no-cache"),
431 );
432 header_map.insert(
433 hyper::header::CONNECTION,
434 hyper::header::HeaderValue::from_static("keep-alive"),
435 );
436 }
437
438 for (k, v) in &http_meta.headers {
439 if k.to_lowercase() != "content-type" {
440 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
441
442 match v {
443 crate::response::HeaderValue::Single(s) => {
444 let header_value = hyper::header::HeaderValue::from_str(s)?;
445 header_map.insert(header_name, header_value);
446 }
447 crate::response::HeaderValue::Multiple(values) => {
448 for value in values {
449 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
450 header_map.append(header_name.clone(), header_value);
451 }
452 }
453 }
454 }
455 }
456 }
457
458 log_response(request_id, status, &header_map, start_time);
459 *builder.headers_mut().unwrap() = header_map;
460
461 let inner_body = match body {
462 ResponseTransport::Empty => Empty::<Bytes>::new()
463 .map_err(|never| match never {})
464 .boxed(),
465 ResponseTransport::Full(bytes) => {
466 if use_brotli {
467 let compressed = compression::compress_full(&bytes)?;
468 Full::new(Bytes::from(compressed))
469 .map_err(|never| match never {})
470 .boxed()
471 } else {
472 Full::new(bytes.into())
473 .map_err(|never| match never {})
474 .boxed()
475 }
476 }
477 ResponseTransport::Stream(rx) => {
478 if use_brotli {
479 compression::compress_stream(rx)
480 } else if is_sse {
481 let stream = futures_util::stream::try_unfold(
483 (ReceiverStream::new(rx), reload_token),
484 |(mut data_rx, token)| async move {
485 tokio::select! {
486 biased;
487 _ = token.cancelled() => {
488 Err(std::io::Error::other("reload").into())
489 }
490 item = StreamExt::next(&mut data_rx) => {
491 match item {
492 Some(data) => Ok(Some((Frame::data(Bytes::from(data)), (data_rx, token)))),
493 None => Ok(None),
494 }
495 }
496 }
497 },
498 );
499 BodyExt::boxed(StreamBody::new(stream))
500 } else {
501 let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
502 BodyExt::boxed(StreamBody::new(stream))
503 }
504 }
505 };
506
507 let logging_body = LoggingBody::new(inner_body, guard);
509 Ok(builder.body(logging_body.boxed())?)
510}