Skip to main content

http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use crate::compression;
16use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
17use crate::request::{resolve_trusted_ip, Request};
18use crate::response::{Response, ResponseBodyType, ResponseTransport};
19use crate::worker::{spawn_eval_thread, PipelineResult};
20
21type BoxError = Box<dyn std::error::Error + Send + Sync>;
22type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
23
24const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
25const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
26const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
27
28pub struct AppConfig {
29    pub trusted_proxies: Vec<ipnet::IpNet>,
30    pub datastar: bool,
31    pub dev: bool,
32}
33
34pub async fn handle<B>(
35    engine: Arc<ArcSwap<crate::Engine>>,
36    addr: Option<SocketAddr>,
37    config: Arc<AppConfig>,
38    req: hyper::Request<B>,
39) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
40where
41    B: hyper::body::Body + Unpin + Send + 'static,
42    B::Data: Into<Bytes> + Clone + Send,
43    B::Error: Into<BoxError> + Send,
44{
45    // Load current engine snapshot - lock-free atomic operation
46    let engine = engine.load_full();
47    match handle_inner(engine, addr, config, req).await {
48        Ok(response) => Ok(response),
49        Err(err) => {
50            eprintln!("Error handling request: {err}");
51            let response = hyper::Response::builder().status(500).body(
52                Full::new(format!("Script error: {err}").into())
53                    .map_err(|never| match never {})
54                    .boxed(),
55            )?;
56            Ok(response)
57        }
58    }
59}
60
61async fn handle_inner<B>(
62    engine: Arc<crate::Engine>,
63    addr: Option<SocketAddr>,
64    config: Arc<AppConfig>,
65    req: hyper::Request<B>,
66) -> HTTPResult
67where
68    B: hyper::body::Body + Unpin + Send + 'static,
69    B::Data: Into<Bytes> + Clone + Send,
70    B::Error: Into<BoxError> + Send,
71{
72    let (parts, mut body) = req.into_parts();
73
74    // Create channels for request body streaming
75    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
76
77    // Spawn task to read request body frames
78    tokio::task::spawn(async move {
79        while let Some(frame) = body.frame().await {
80            match frame {
81                Ok(frame) => {
82                    if let Some(data) = frame.data_ref() {
83                        let bytes: Bytes = (*data).clone().into();
84                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
85                            break;
86                        }
87                    }
88                }
89                Err(err) => {
90                    let _ = body_tx.send(Err(err.into())).await;
91                    break;
92                }
93            }
94        }
95    });
96
97    // Create ByteStream for Nu pipeline
98    let stream = nu_protocol::ByteStream::from_fn(
99        nu_protocol::Span::unknown(),
100        engine.state.signals().clone(),
101        nu_protocol::ByteStreamType::Unknown,
102        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
103            Some(Ok(bytes)) => {
104                buffer.extend_from_slice(&bytes);
105                Ok(true)
106            }
107            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
108                error: "Body read error".into(),
109                msg: err.to_string(),
110                span: None,
111                help: None,
112                inner: vec![],
113            }),
114            None => Ok(false),
115        },
116    );
117
118    // Generate request ID and guard for logging
119    let start_time = Instant::now();
120    let request_id = scru128::new();
121    let guard = RequestGuard::new(request_id);
122
123    let remote_ip = addr.as_ref().map(|a| a.ip());
124    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
125
126    let request = Request {
127        proto: format!("{:?}", parts.version),
128        method: parts.method.clone(),
129        authority: parts.uri.authority().map(|a| a.to_string()),
130        remote_ip,
131        remote_port: addr.as_ref().map(|a| a.port()),
132        trusted_ip,
133        headers: parts.headers.clone(),
134        uri: parts.uri.clone(),
135        path: parts.uri.path().to_string(),
136        query: parts
137            .uri
138            .query()
139            .map(|v| {
140                url::form_urlencoded::parse(v.as_bytes())
141                    .into_owned()
142                    .collect()
143            })
144            .unwrap_or_else(std::collections::HashMap::new),
145    };
146
147    // Phase 1: Log request
148    log_request(request_id, &request);
149
150    // Built-in route: serve embedded Datastar JS bundle (requires --datastar flag)
151    if config.datastar && request.path == DATASTAR_JS_PATH {
152        let use_brotli = compression::accepts_brotli(&parts.headers);
153        let mut header_map = hyper::header::HeaderMap::new();
154        header_map.insert(
155            hyper::header::CONTENT_TYPE,
156            hyper::header::HeaderValue::from_static("application/javascript"),
157        );
158        header_map.insert(
159            hyper::header::CACHE_CONTROL,
160            hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
161        );
162        let body = if use_brotli {
163            header_map.insert(
164                hyper::header::CONTENT_ENCODING,
165                hyper::header::HeaderValue::from_static("br"),
166            );
167            header_map.insert(
168                hyper::header::VARY,
169                hyper::header::HeaderValue::from_static("accept-encoding"),
170            );
171            Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
172                .map_err(|never| match never {})
173                .boxed()
174        } else {
175            Full::new(Bytes::from_static(DATASTAR_JS))
176                .map_err(|never| match never {})
177                .boxed()
178        };
179        log_response(request_id, 200, &header_map, start_time);
180        let logging_body = LoggingBody::new(body, guard);
181        let mut response = hyper::Response::builder()
182            .status(200)
183            .body(logging_body.boxed())?;
184        *response.headers_mut() = header_map;
185        return Ok(response);
186    }
187
188    let sse_cancel_token = engine.sse_cancel_token.clone();
189    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
190
191    // Wait for both:
192    // 1. Special response (from .static or .reverse-proxy) - None if normal response
193    // 2. Body pipeline result (includes http.response metadata for normal responses)
194    let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
195        tokio::join!(async { meta_rx.await.ok() }, async {
196            bridged_body.await.map_err(|e| e.into())
197        });
198
199    let use_brotli = compression::accepts_brotli(&parts.headers);
200
201    // Check if we got a special response (.static or .reverse-proxy)
202    match special_response.as_ref().map(|r| &r.body_type) {
203        Some(ResponseBodyType::Normal) | None => {
204            // Normal response - use metadata from pipeline
205            build_normal_response(
206                body_result?,
207                use_brotli,
208                guard,
209                start_time,
210                sse_cancel_token,
211            )
212            .await
213        }
214        Some(ResponseBodyType::Static {
215            root,
216            path,
217            fallback,
218        }) => {
219            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
220            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
221            *static_req.method_mut() = parts.method.clone();
222            *static_req.headers_mut() = parts.headers.clone();
223
224            let res = if let Some(fallback) = fallback {
225                let fp = root.join(fallback);
226                ServeDir::new(root)
227                    .fallback(ServeFile::new(fp))
228                    .call(static_req)
229                    .await?
230            } else {
231                ServeDir::new(root).call(static_req).await?
232            };
233            let (res_parts, body) = res.into_parts();
234            log_response(
235                request_id,
236                res_parts.status.as_u16(),
237                &res_parts.headers,
238                start_time,
239            );
240
241            let bytes = body.collect().await?.to_bytes();
242            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
243            let logging_body = LoggingBody::new(inner_body, guard);
244            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
245            Ok(res)
246        }
247        Some(ResponseBodyType::ReverseProxy {
248            target_url,
249            headers,
250            preserve_host,
251            strip_prefix,
252            request_body,
253            query,
254        }) => {
255            let body = Full::new(Bytes::from(request_body.clone()));
256            let mut proxy_req = hyper::Request::new(body);
257
258            // Handle strip_prefix
259            let path = if let Some(prefix) = strip_prefix {
260                parts
261                    .uri
262                    .path()
263                    .strip_prefix(prefix)
264                    .unwrap_or(parts.uri.path())
265            } else {
266                parts.uri.path()
267            };
268
269            // Build target URI
270            let target_uri = {
271                let query_string = if let Some(custom_query) = query {
272                    // Use custom query - convert HashMap to query string
273                    url::form_urlencoded::Serializer::new(String::new())
274                        .extend_pairs(custom_query.iter())
275                        .finish()
276                } else if let Some(orig_query) = parts.uri.query() {
277                    // Use original query string
278                    orig_query.to_string()
279                } else {
280                    String::new()
281                };
282
283                if query_string.is_empty() {
284                    format!("{target_url}{path}")
285                } else {
286                    format!("{target_url}{path}?{query_string}")
287                }
288            };
289
290            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
291            *proxy_req.method_mut() = parts.method.clone();
292
293            // Copy original headers
294            let mut header_map = parts.headers.clone();
295
296            // Update Content-Length to match the new body
297            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
298                header_map.insert(
299                    hyper::header::CONTENT_LENGTH,
300                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
301                );
302            }
303
304            // Add custom headers
305            for (k, v) in headers {
306                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
307
308                match v {
309                    crate::response::HeaderValue::Single(s) => {
310                        let header_value = hyper::header::HeaderValue::from_str(s)?;
311                        header_map.insert(header_name, header_value);
312                    }
313                    crate::response::HeaderValue::Multiple(values) => {
314                        for value in values {
315                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
316                                header_map.append(header_name.clone(), header_value);
317                            }
318                        }
319                    }
320                }
321            }
322
323            // Handle preserve_host
324            if !preserve_host {
325                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
326                    if let Some(authority) = target_uri.authority() {
327                        header_map.insert(
328                            hyper::header::HOST,
329                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
330                        );
331                    }
332                }
333            }
334
335            *proxy_req.headers_mut() = header_map;
336
337            // Create a simple HTTP client and forward the request
338            let client =
339                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
340                    .build_http();
341
342            match client.request(proxy_req).await {
343                Ok(response) => {
344                    let (res_parts, body) = response.into_parts();
345                    log_response(
346                        request_id,
347                        res_parts.status.as_u16(),
348                        &res_parts.headers,
349                        start_time,
350                    );
351
352                    let inner_body = body.map_err(|e| e.into()).boxed();
353                    let logging_body = LoggingBody::new(inner_body, guard);
354                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
355                    Ok(res)
356                }
357                Err(_e) => {
358                    let empty_headers = hyper::header::HeaderMap::new();
359                    log_response(request_id, 502, &empty_headers, start_time);
360
361                    let inner_body = Full::new("Bad Gateway".into())
362                        .map_err(|never| match never {})
363                        .boxed();
364                    let logging_body = LoggingBody::new(inner_body, guard);
365                    let response = hyper::Response::builder()
366                        .status(502)
367                        .body(logging_body.boxed())?;
368                    Ok(response)
369                }
370            }
371        }
372    }
373}
374
375async fn build_normal_response(
376    pipeline_result: PipelineResult,
377    use_brotli: bool,
378    guard: RequestGuard,
379    start_time: Instant,
380    sse_cancel_token: CancellationToken,
381) -> HTTPResult {
382    let request_id = guard.request_id();
383    let (inferred_content_type, http_meta, body) = pipeline_result;
384    let status = match (http_meta.status, &body) {
385        (Some(s), _) => s,
386        (None, ResponseTransport::Empty) => 204,
387        (None, _) => 200,
388    };
389    let mut builder = hyper::Response::builder().status(status);
390    let mut header_map = hyper::header::HeaderMap::new();
391
392    // Content-type precedence:
393    // 1. Explicit in http.response headers
394    // 2. Pipeline metadata (from `to json`, etc.)
395    // 3. Inferred: record->json, binary->octet-stream, list/stream of records->ndjson, empty->None
396    // 4. Default: text/html
397    let content_type = http_meta
398        .headers
399        .get("content-type")
400        .or(http_meta.headers.get("Content-Type"))
401        .and_then(|hv| match hv {
402            crate::response::HeaderValue::Single(s) => Some(s.clone()),
403            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
404        })
405        .or(inferred_content_type)
406        .or_else(|| {
407            if matches!(body, ResponseTransport::Empty) {
408                None
409            } else {
410                Some("text/html; charset=utf-8".to_string())
411            }
412        });
413
414    if let Some(ref ct) = content_type {
415        header_map.insert(
416            hyper::header::CONTENT_TYPE,
417            hyper::header::HeaderValue::from_str(ct)?,
418        );
419    }
420
421    // Add compression headers if using brotli
422    if use_brotli {
423        header_map.insert(
424            hyper::header::CONTENT_ENCODING,
425            hyper::header::HeaderValue::from_static("br"),
426        );
427        header_map.insert(
428            hyper::header::VARY,
429            hyper::header::HeaderValue::from_static("accept-encoding"),
430        );
431    }
432
433    // Add SSE-required headers for event streams
434    let is_sse = content_type.as_deref() == Some("text/event-stream");
435    if is_sse {
436        header_map.insert(
437            hyper::header::CACHE_CONTROL,
438            hyper::header::HeaderValue::from_static("no-cache"),
439        );
440        header_map.insert(
441            hyper::header::CONNECTION,
442            hyper::header::HeaderValue::from_static("keep-alive"),
443        );
444    }
445
446    for (k, v) in &http_meta.headers {
447        if k.to_lowercase() != "content-type" {
448            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
449
450            match v {
451                crate::response::HeaderValue::Single(s) => {
452                    let header_value = hyper::header::HeaderValue::from_str(s)?;
453                    header_map.insert(header_name, header_value);
454                }
455                crate::response::HeaderValue::Multiple(values) => {
456                    for value in values {
457                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
458                            header_map.append(header_name.clone(), header_value);
459                        }
460                    }
461                }
462            }
463        }
464    }
465
466    log_response(request_id, status, &header_map, start_time);
467    *builder.headers_mut().unwrap() = header_map;
468
469    let inner_body = match body {
470        ResponseTransport::Empty => Empty::<Bytes>::new()
471            .map_err(|never| match never {})
472            .boxed(),
473        ResponseTransport::Full(bytes) => {
474            if use_brotli {
475                let compressed = compression::compress_full(&bytes)?;
476                Full::new(Bytes::from(compressed))
477                    .map_err(|never| match never {})
478                    .boxed()
479            } else {
480                Full::new(bytes.into())
481                    .map_err(|never| match never {})
482                    .boxed()
483            }
484        }
485        ResponseTransport::Stream(rx) => {
486            // Layer 1: base byte stream, with reload-abort for SSE
487            let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
488                // SSE streams abort on cancellation (reload or shutdown)
489                Box::pin(futures_util::stream::unfold(
490                    (ReceiverStream::new(rx), sse_cancel_token),
491                    |(mut data_rx, token)| async move {
492                        tokio::select! {
493                            biased;
494                            _ = token.cancelled() => None,
495                            item = StreamExt::next(&mut data_rx) => {
496                                item.map(|data| (data, (data_rx, token)))
497                            }
498                        }
499                    },
500                ))
501            } else {
502                Box::pin(ReceiverStream::new(rx))
503            };
504
505            // Layer 2: optionally compress, then frame
506            if use_brotli {
507                let brotli = compression::BrotliStream::new(byte_stream);
508                BodyExt::boxed(StreamBody::new(brotli))
509            } else {
510                let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
511                BodyExt::boxed(StreamBody::new(stream))
512            }
513        }
514    };
515
516    // Wrap with LoggingBody for phase 3 (complete) logging
517    let logging_body = LoggingBody::new(inner_body, guard);
518    Ok(builder.body(logging_body.boxed())?)
519}