Skip to main content

http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use nu_protocol::shell_error::generic::GenericError;
16
17use crate::compression;
18use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
19use crate::request::{resolve_trusted_ip, Request};
20use crate::response::{Response, ResponseBodyType, ResponseTransport};
21use crate::worker::{spawn_eval_thread, PipelineResult};
22
23type BoxError = Box<dyn std::error::Error + Send + Sync>;
24type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
25
26const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
27const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
28const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
29
30pub struct AppConfig {
31    pub trusted_proxies: Vec<ipnet::IpNet>,
32    pub datastar: bool,
33    pub dev: bool,
34}
35
36pub async fn handle<B>(
37    engine: Arc<ArcSwap<crate::Engine>>,
38    addr: Option<SocketAddr>,
39    config: Arc<AppConfig>,
40    req: hyper::Request<B>,
41) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
42where
43    B: hyper::body::Body + Unpin + Send + 'static,
44    B::Data: Into<Bytes> + Clone + Send,
45    B::Error: Into<BoxError> + Send,
46{
47    // Load current engine snapshot - lock-free atomic operation
48    let engine = engine.load_full();
49    match handle_inner(engine, addr, config, req).await {
50        Ok(response) => Ok(response),
51        Err(err) => {
52            eprintln!("Error handling request: {err}");
53            let response = hyper::Response::builder().status(500).body(
54                Full::new(format!("Script error: {err}").into())
55                    .map_err(|never| match never {})
56                    .boxed(),
57            )?;
58            Ok(response)
59        }
60    }
61}
62
63async fn handle_inner<B>(
64    engine: Arc<crate::Engine>,
65    addr: Option<SocketAddr>,
66    config: Arc<AppConfig>,
67    req: hyper::Request<B>,
68) -> HTTPResult
69where
70    B: hyper::body::Body + Unpin + Send + 'static,
71    B::Data: Into<Bytes> + Clone + Send,
72    B::Error: Into<BoxError> + Send,
73{
74    let (parts, mut body) = req.into_parts();
75
76    // Create channels for request body streaming
77    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
78
79    // Spawn task to read request body frames
80    tokio::task::spawn(async move {
81        while let Some(frame) = body.frame().await {
82            match frame {
83                Ok(frame) => {
84                    if let Some(data) = frame.data_ref() {
85                        let bytes: Bytes = (*data).clone().into();
86                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
87                            break;
88                        }
89                    }
90                }
91                Err(err) => {
92                    let _ = body_tx.send(Err(err.into())).await;
93                    break;
94                }
95            }
96        }
97    });
98
99    // Create ByteStream for Nu pipeline
100    let stream = nu_protocol::ByteStream::from_fn(
101        nu_protocol::Span::unknown(),
102        engine.state.signals().clone(),
103        nu_protocol::ByteStreamType::Unknown,
104        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
105            Some(Ok(bytes)) => {
106                buffer.extend_from_slice(&bytes);
107                Ok(true)
108            }
109            Some(Err(err)) => Err(nu_protocol::ShellError::Generic(
110                GenericError::new_internal("Body read error", err.to_string()),
111            )),
112            None => Ok(false),
113        },
114    );
115
116    // Generate request ID and guard for logging
117    let start_time = Instant::now();
118    let request_id = scru128::new();
119    let guard = RequestGuard::new(request_id);
120
121    let remote_ip = addr.as_ref().map(|a| a.ip());
122    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
123
124    let request = Request {
125        proto: format!("{:?}", parts.version),
126        method: parts.method.clone(),
127        authority: parts.uri.authority().map(|a| a.to_string()),
128        remote_ip,
129        remote_port: addr.as_ref().map(|a| a.port()),
130        trusted_ip,
131        headers: parts.headers.clone(),
132        uri: parts.uri.clone(),
133        path: parts.uri.path().to_string(),
134        query: parts
135            .uri
136            .query()
137            .map(|v| {
138                url::form_urlencoded::parse(v.as_bytes())
139                    .into_owned()
140                    .collect()
141            })
142            .unwrap_or_else(std::collections::HashMap::new),
143    };
144
145    // Phase 1: Log request
146    log_request(request_id, &request);
147
148    // Built-in route: serve embedded Datastar JS bundle (requires --datastar flag)
149    if config.datastar && request.path == DATASTAR_JS_PATH {
150        let use_brotli = compression::accepts_brotli(&parts.headers);
151        let mut header_map = hyper::header::HeaderMap::new();
152        header_map.insert(
153            hyper::header::CONTENT_TYPE,
154            hyper::header::HeaderValue::from_static("application/javascript"),
155        );
156        header_map.insert(
157            hyper::header::CACHE_CONTROL,
158            hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
159        );
160        let body = if use_brotli {
161            header_map.insert(
162                hyper::header::CONTENT_ENCODING,
163                hyper::header::HeaderValue::from_static("br"),
164            );
165            header_map.insert(
166                hyper::header::VARY,
167                hyper::header::HeaderValue::from_static("accept-encoding"),
168            );
169            Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
170                .map_err(|never| match never {})
171                .boxed()
172        } else {
173            Full::new(Bytes::from_static(DATASTAR_JS))
174                .map_err(|never| match never {})
175                .boxed()
176        };
177        log_response(request_id, 200, &header_map, start_time);
178        let logging_body = LoggingBody::new(body, guard);
179        let mut response = hyper::Response::builder()
180            .status(200)
181            .body(logging_body.boxed())?;
182        *response.headers_mut() = header_map;
183        return Ok(response);
184    }
185
186    let sse_cancel_token = engine.sse_cancel_token.clone();
187    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
188
189    // Wait for both:
190    // 1. Special response (from .static or .reverse-proxy) - None if normal response
191    // 2. Body pipeline result (includes http.response metadata for normal responses)
192    let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
193        tokio::join!(async { meta_rx.await.ok() }, async {
194            bridged_body.await.map_err(|e| e.into())
195        });
196
197    let use_brotli = compression::accepts_brotli(&parts.headers);
198
199    // Check if we got a special response (.static or .reverse-proxy)
200    match special_response.as_ref().map(|r| &r.body_type) {
201        Some(ResponseBodyType::Normal) | None => {
202            // Normal response - use metadata from pipeline
203            build_normal_response(
204                body_result?,
205                use_brotli,
206                guard,
207                start_time,
208                sse_cancel_token,
209            )
210            .await
211        }
212        Some(ResponseBodyType::Static {
213            root,
214            path,
215            fallback,
216        }) => {
217            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
218            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
219            *static_req.method_mut() = parts.method.clone();
220            *static_req.headers_mut() = parts.headers.clone();
221
222            let res = if let Some(fallback) = fallback {
223                let fp = root.join(fallback);
224                ServeDir::new(root)
225                    .fallback(ServeFile::new(fp))
226                    .call(static_req)
227                    .await?
228            } else {
229                ServeDir::new(root).call(static_req).await?
230            };
231            let (res_parts, body) = res.into_parts();
232            log_response(
233                request_id,
234                res_parts.status.as_u16(),
235                &res_parts.headers,
236                start_time,
237            );
238
239            let bytes = body.collect().await?.to_bytes();
240            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
241            let logging_body = LoggingBody::new(inner_body, guard);
242            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
243            Ok(res)
244        }
245        Some(ResponseBodyType::ReverseProxy {
246            target_url,
247            headers,
248            preserve_host,
249            strip_prefix,
250            request_body,
251            query,
252        }) => {
253            let body = Full::new(Bytes::from(request_body.clone()));
254            let mut proxy_req = hyper::Request::new(body);
255
256            // Handle strip_prefix
257            let path = if let Some(prefix) = strip_prefix {
258                parts
259                    .uri
260                    .path()
261                    .strip_prefix(prefix)
262                    .unwrap_or(parts.uri.path())
263            } else {
264                parts.uri.path()
265            };
266
267            // Ensure the path starts with '/' to prevent SSRF via URL
268            // authority confusion (e.g. path "@evil.com/" being concatenated
269            // to produce "http://backend@evil.com/")
270            let path = if path.starts_with('/') {
271                path.to_string()
272            } else {
273                format!("/{path}")
274            };
275
276            // Build target URI
277            let target_uri = {
278                let query_string = if let Some(custom_query) = query {
279                    // Use custom query - convert HashMap to query string
280                    url::form_urlencoded::Serializer::new(String::new())
281                        .extend_pairs(custom_query.iter())
282                        .finish()
283                } else if let Some(orig_query) = parts.uri.query() {
284                    // Use original query string
285                    orig_query.to_string()
286                } else {
287                    String::new()
288                };
289
290                if query_string.is_empty() {
291                    format!("{target_url}{path}")
292                } else {
293                    format!("{target_url}{path}?{query_string}")
294                }
295            };
296
297            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
298            *proxy_req.method_mut() = parts.method.clone();
299
300            // Copy original headers
301            let mut header_map = parts.headers.clone();
302
303            // Update Content-Length to match the new body
304            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
305                header_map.insert(
306                    hyper::header::CONTENT_LENGTH,
307                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
308                );
309            }
310
311            // Add custom headers
312            for (k, v) in headers {
313                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
314
315                match v {
316                    crate::response::HeaderValue::Single(s) => {
317                        let header_value = hyper::header::HeaderValue::from_str(s)?;
318                        header_map.insert(header_name, header_value);
319                    }
320                    crate::response::HeaderValue::Multiple(values) => {
321                        for value in values {
322                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
323                                header_map.append(header_name.clone(), header_value);
324                            }
325                        }
326                    }
327                }
328            }
329
330            // Handle preserve_host
331            if !preserve_host {
332                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
333                    if let Some(authority) = target_uri.authority() {
334                        header_map.insert(
335                            hyper::header::HOST,
336                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
337                        );
338                    }
339                }
340            }
341
342            *proxy_req.headers_mut() = header_map;
343
344            // Create a simple HTTP client and forward the request
345            let client =
346                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
347                    .build_http();
348
349            match client.request(proxy_req).await {
350                Ok(response) => {
351                    let (res_parts, body) = response.into_parts();
352                    log_response(
353                        request_id,
354                        res_parts.status.as_u16(),
355                        &res_parts.headers,
356                        start_time,
357                    );
358
359                    let inner_body = body.map_err(|e| e.into()).boxed();
360                    let logging_body = LoggingBody::new(inner_body, guard);
361                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
362                    Ok(res)
363                }
364                Err(_e) => {
365                    let empty_headers = hyper::header::HeaderMap::new();
366                    log_response(request_id, 502, &empty_headers, start_time);
367
368                    let inner_body = Full::new("Bad Gateway".into())
369                        .map_err(|never| match never {})
370                        .boxed();
371                    let logging_body = LoggingBody::new(inner_body, guard);
372                    let response = hyper::Response::builder()
373                        .status(502)
374                        .body(logging_body.boxed())?;
375                    Ok(response)
376                }
377            }
378        }
379    }
380}
381
382async fn build_normal_response(
383    pipeline_result: PipelineResult,
384    use_brotli: bool,
385    guard: RequestGuard,
386    start_time: Instant,
387    sse_cancel_token: CancellationToken,
388) -> HTTPResult {
389    let request_id = guard.request_id();
390    let (inferred_content_type, http_meta, body) = pipeline_result;
391    let status = match (http_meta.status, &body) {
392        (Some(s), _) => s,
393        (None, ResponseTransport::Empty) => 204,
394        (None, _) => 200,
395    };
396    let mut builder = hyper::Response::builder().status(status);
397    let mut header_map = hyper::header::HeaderMap::new();
398
399    // Content-type precedence:
400    // 1. Explicit in http.response headers
401    // 2. Pipeline metadata (from `to json`, etc.)
402    // 3. Inferred: record->json, binary->octet-stream, list/stream of records->ndjson, empty->None
403    // 4. Default: text/html
404    let content_type = http_meta
405        .headers
406        .get("content-type")
407        .or(http_meta.headers.get("Content-Type"))
408        .and_then(|hv| match hv {
409            crate::response::HeaderValue::Single(s) => Some(s.clone()),
410            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
411        })
412        .or(inferred_content_type)
413        .or_else(|| {
414            if matches!(body, ResponseTransport::Empty) {
415                None
416            } else {
417                Some("text/html; charset=utf-8".to_string())
418            }
419        });
420
421    if let Some(ref ct) = content_type {
422        header_map.insert(
423            hyper::header::CONTENT_TYPE,
424            hyper::header::HeaderValue::from_str(ct)?,
425        );
426    }
427
428    // Add compression headers if using brotli
429    if use_brotli {
430        header_map.insert(
431            hyper::header::CONTENT_ENCODING,
432            hyper::header::HeaderValue::from_static("br"),
433        );
434        header_map.insert(
435            hyper::header::VARY,
436            hyper::header::HeaderValue::from_static("accept-encoding"),
437        );
438    }
439
440    // Add SSE-required headers for event streams
441    let is_sse = content_type.as_deref() == Some("text/event-stream");
442    if is_sse {
443        header_map.insert(
444            hyper::header::CACHE_CONTROL,
445            hyper::header::HeaderValue::from_static("no-cache"),
446        );
447        header_map.insert(
448            hyper::header::CONNECTION,
449            hyper::header::HeaderValue::from_static("keep-alive"),
450        );
451    }
452
453    for (k, v) in &http_meta.headers {
454        if k.to_lowercase() != "content-type" {
455            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
456
457            match v {
458                crate::response::HeaderValue::Single(s) => {
459                    let header_value = hyper::header::HeaderValue::from_str(s)?;
460                    header_map.insert(header_name, header_value);
461                }
462                crate::response::HeaderValue::Multiple(values) => {
463                    for value in values {
464                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
465                            header_map.append(header_name.clone(), header_value);
466                        }
467                    }
468                }
469            }
470        }
471    }
472
473    log_response(request_id, status, &header_map, start_time);
474    *builder.headers_mut().unwrap() = header_map;
475
476    let inner_body = match body {
477        ResponseTransport::Empty => Empty::<Bytes>::new()
478            .map_err(|never| match never {})
479            .boxed(),
480        ResponseTransport::Full(bytes) => {
481            if use_brotli {
482                let compressed = compression::compress_full(&bytes)?;
483                Full::new(Bytes::from(compressed))
484                    .map_err(|never| match never {})
485                    .boxed()
486            } else {
487                Full::new(bytes.into())
488                    .map_err(|never| match never {})
489                    .boxed()
490            }
491        }
492        ResponseTransport::Stream(rx) => {
493            // Layer 1: base byte stream, with reload-abort for SSE
494            let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
495                // SSE streams abort on cancellation (reload or shutdown)
496                Box::pin(futures_util::stream::unfold(
497                    (ReceiverStream::new(rx), sse_cancel_token),
498                    |(mut data_rx, token)| async move {
499                        tokio::select! {
500                            biased;
501                            _ = token.cancelled() => None,
502                            item = StreamExt::next(&mut data_rx) => {
503                                item.map(|data| (data, (data_rx, token)))
504                            }
505                        }
506                    },
507                ))
508            } else {
509                Box::pin(ReceiverStream::new(rx))
510            };
511
512            // Layer 2: optionally compress, then frame
513            if use_brotli {
514                let brotli = compression::BrotliStream::new(byte_stream);
515                BodyExt::boxed(StreamBody::new(brotli))
516            } else {
517                let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
518                BodyExt::boxed(StreamBody::new(stream))
519            }
520        }
521    };
522
523    // Wrap with LoggingBody for phase 3 (complete) logging
524    let logging_body = LoggingBody::new(inner_body, guard);
525    Ok(builder.body(logging_body.boxed())?)
526}