Skip to main content

http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::time::Instant;
5
6use arc_swap::ArcSwap;
7use futures_util::{Stream, StreamExt};
8use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
9use hyper::body::{Bytes, Frame};
10use tokio_stream::wrappers::ReceiverStream;
11use tokio_util::sync::CancellationToken;
12use tower::Service;
13use tower_http::services::{ServeDir, ServeFile};
14
15use crate::compression;
16use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
17use crate::request::{resolve_trusted_ip, Request};
18use crate::response::{Response, ResponseBodyType, ResponseTransport};
19use crate::worker::{spawn_eval_thread, PipelineResult};
20
21type BoxError = Box<dyn std::error::Error + Send + Sync>;
22type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
23
24const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.8.js";
25const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js");
26const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.8.js.br");
27
28pub struct AppConfig {
29    pub trusted_proxies: Vec<ipnet::IpNet>,
30    pub datastar: bool,
31    pub dev: bool,
32}
33
34pub async fn handle<B>(
35    engine: Arc<ArcSwap<crate::Engine>>,
36    addr: Option<SocketAddr>,
37    config: Arc<AppConfig>,
38    req: hyper::Request<B>,
39) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
40where
41    B: hyper::body::Body + Unpin + Send + 'static,
42    B::Data: Into<Bytes> + Clone + Send,
43    B::Error: Into<BoxError> + Send,
44{
45    // Load current engine snapshot - lock-free atomic operation
46    let engine = engine.load_full();
47    match handle_inner(engine, addr, config, req).await {
48        Ok(response) => Ok(response),
49        Err(err) => {
50            eprintln!("Error handling request: {err}");
51            let response = hyper::Response::builder().status(500).body(
52                Full::new(format!("Script error: {err}").into())
53                    .map_err(|never| match never {})
54                    .boxed(),
55            )?;
56            Ok(response)
57        }
58    }
59}
60
61async fn handle_inner<B>(
62    engine: Arc<crate::Engine>,
63    addr: Option<SocketAddr>,
64    config: Arc<AppConfig>,
65    req: hyper::Request<B>,
66) -> HTTPResult
67where
68    B: hyper::body::Body + Unpin + Send + 'static,
69    B::Data: Into<Bytes> + Clone + Send,
70    B::Error: Into<BoxError> + Send,
71{
72    let (parts, mut body) = req.into_parts();
73
74    // Create channels for request body streaming
75    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
76
77    // Spawn task to read request body frames
78    tokio::task::spawn(async move {
79        while let Some(frame) = body.frame().await {
80            match frame {
81                Ok(frame) => {
82                    if let Some(data) = frame.data_ref() {
83                        let bytes: Bytes = (*data).clone().into();
84                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
85                            break;
86                        }
87                    }
88                }
89                Err(err) => {
90                    let _ = body_tx.send(Err(err.into())).await;
91                    break;
92                }
93            }
94        }
95    });
96
97    // Create ByteStream for Nu pipeline
98    let stream = nu_protocol::ByteStream::from_fn(
99        nu_protocol::Span::unknown(),
100        engine.state.signals().clone(),
101        nu_protocol::ByteStreamType::Unknown,
102        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
103            Some(Ok(bytes)) => {
104                buffer.extend_from_slice(&bytes);
105                Ok(true)
106            }
107            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
108                error: "Body read error".into(),
109                msg: err.to_string(),
110                span: None,
111                help: None,
112                inner: vec![],
113            }),
114            None => Ok(false),
115        },
116    );
117
118    // Generate request ID and guard for logging
119    let start_time = Instant::now();
120    let request_id = scru128::new();
121    let guard = RequestGuard::new(request_id);
122
123    let remote_ip = addr.as_ref().map(|a| a.ip());
124    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
125
126    let request = Request {
127        proto: format!("{:?}", parts.version),
128        method: parts.method.clone(),
129        authority: parts.uri.authority().map(|a| a.to_string()),
130        remote_ip,
131        remote_port: addr.as_ref().map(|a| a.port()),
132        trusted_ip,
133        headers: parts.headers.clone(),
134        uri: parts.uri.clone(),
135        path: parts.uri.path().to_string(),
136        query: parts
137            .uri
138            .query()
139            .map(|v| {
140                url::form_urlencoded::parse(v.as_bytes())
141                    .into_owned()
142                    .collect()
143            })
144            .unwrap_or_else(std::collections::HashMap::new),
145    };
146
147    // Phase 1: Log request
148    log_request(request_id, &request);
149
150    // Built-in route: serve embedded Datastar JS bundle (requires --datastar flag)
151    if config.datastar && request.path == DATASTAR_JS_PATH {
152        let use_brotli = compression::accepts_brotli(&parts.headers);
153        let mut header_map = hyper::header::HeaderMap::new();
154        header_map.insert(
155            hyper::header::CONTENT_TYPE,
156            hyper::header::HeaderValue::from_static("application/javascript"),
157        );
158        header_map.insert(
159            hyper::header::CACHE_CONTROL,
160            hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
161        );
162        let body = if use_brotli {
163            header_map.insert(
164                hyper::header::CONTENT_ENCODING,
165                hyper::header::HeaderValue::from_static("br"),
166            );
167            header_map.insert(
168                hyper::header::VARY,
169                hyper::header::HeaderValue::from_static("accept-encoding"),
170            );
171            Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
172                .map_err(|never| match never {})
173                .boxed()
174        } else {
175            Full::new(Bytes::from_static(DATASTAR_JS))
176                .map_err(|never| match never {})
177                .boxed()
178        };
179        log_response(request_id, 200, &header_map, start_time);
180        let logging_body = LoggingBody::new(body, guard);
181        let mut response = hyper::Response::builder()
182            .status(200)
183            .body(logging_body.boxed())?;
184        *response.headers_mut() = header_map;
185        return Ok(response);
186    }
187
188    let reload_token = engine.reload_token.clone();
189    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
190
191    // Wait for both:
192    // 1. Special response (from .static or .reverse-proxy) - None if normal response
193    // 2. Body pipeline result (includes http.response metadata for normal responses)
194    let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
195        tokio::join!(async { meta_rx.await.ok() }, async {
196            bridged_body.await.map_err(|e| e.into())
197        });
198
199    let use_brotli = compression::accepts_brotli(&parts.headers);
200
201    // Check if we got a special response (.static or .reverse-proxy)
202    match special_response.as_ref().map(|r| &r.body_type) {
203        Some(ResponseBodyType::Normal) | None => {
204            // Normal response - use metadata from pipeline
205            build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
206        }
207        Some(ResponseBodyType::Static {
208            root,
209            path,
210            fallback,
211        }) => {
212            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
213            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
214            *static_req.method_mut() = parts.method.clone();
215            *static_req.headers_mut() = parts.headers.clone();
216
217            let res = if let Some(fallback) = fallback {
218                let fp = root.join(fallback);
219                ServeDir::new(root)
220                    .fallback(ServeFile::new(fp))
221                    .call(static_req)
222                    .await?
223            } else {
224                ServeDir::new(root).call(static_req).await?
225            };
226            let (res_parts, body) = res.into_parts();
227            log_response(
228                request_id,
229                res_parts.status.as_u16(),
230                &res_parts.headers,
231                start_time,
232            );
233
234            let bytes = body.collect().await?.to_bytes();
235            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
236            let logging_body = LoggingBody::new(inner_body, guard);
237            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
238            Ok(res)
239        }
240        Some(ResponseBodyType::ReverseProxy {
241            target_url,
242            headers,
243            preserve_host,
244            strip_prefix,
245            request_body,
246            query,
247        }) => {
248            let body = Full::new(Bytes::from(request_body.clone()));
249            let mut proxy_req = hyper::Request::new(body);
250
251            // Handle strip_prefix
252            let path = if let Some(prefix) = strip_prefix {
253                parts
254                    .uri
255                    .path()
256                    .strip_prefix(prefix)
257                    .unwrap_or(parts.uri.path())
258            } else {
259                parts.uri.path()
260            };
261
262            // Build target URI
263            let target_uri = {
264                let query_string = if let Some(custom_query) = query {
265                    // Use custom query - convert HashMap to query string
266                    url::form_urlencoded::Serializer::new(String::new())
267                        .extend_pairs(custom_query.iter())
268                        .finish()
269                } else if let Some(orig_query) = parts.uri.query() {
270                    // Use original query string
271                    orig_query.to_string()
272                } else {
273                    String::new()
274                };
275
276                if query_string.is_empty() {
277                    format!("{target_url}{path}")
278                } else {
279                    format!("{target_url}{path}?{query_string}")
280                }
281            };
282
283            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
284            *proxy_req.method_mut() = parts.method.clone();
285
286            // Copy original headers
287            let mut header_map = parts.headers.clone();
288
289            // Update Content-Length to match the new body
290            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
291                header_map.insert(
292                    hyper::header::CONTENT_LENGTH,
293                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
294                );
295            }
296
297            // Add custom headers
298            for (k, v) in headers {
299                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
300
301                match v {
302                    crate::response::HeaderValue::Single(s) => {
303                        let header_value = hyper::header::HeaderValue::from_str(s)?;
304                        header_map.insert(header_name, header_value);
305                    }
306                    crate::response::HeaderValue::Multiple(values) => {
307                        for value in values {
308                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
309                                header_map.append(header_name.clone(), header_value);
310                            }
311                        }
312                    }
313                }
314            }
315
316            // Handle preserve_host
317            if !preserve_host {
318                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
319                    if let Some(authority) = target_uri.authority() {
320                        header_map.insert(
321                            hyper::header::HOST,
322                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
323                        );
324                    }
325                }
326            }
327
328            *proxy_req.headers_mut() = header_map;
329
330            // Create a simple HTTP client and forward the request
331            let client =
332                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
333                    .build_http();
334
335            match client.request(proxy_req).await {
336                Ok(response) => {
337                    let (res_parts, body) = response.into_parts();
338                    log_response(
339                        request_id,
340                        res_parts.status.as_u16(),
341                        &res_parts.headers,
342                        start_time,
343                    );
344
345                    let inner_body = body.map_err(|e| e.into()).boxed();
346                    let logging_body = LoggingBody::new(inner_body, guard);
347                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
348                    Ok(res)
349                }
350                Err(_e) => {
351                    let empty_headers = hyper::header::HeaderMap::new();
352                    log_response(request_id, 502, &empty_headers, start_time);
353
354                    let inner_body = Full::new("Bad Gateway".into())
355                        .map_err(|never| match never {})
356                        .boxed();
357                    let logging_body = LoggingBody::new(inner_body, guard);
358                    let response = hyper::Response::builder()
359                        .status(502)
360                        .body(logging_body.boxed())?;
361                    Ok(response)
362                }
363            }
364        }
365    }
366}
367
368async fn build_normal_response(
369    pipeline_result: PipelineResult,
370    use_brotli: bool,
371    guard: RequestGuard,
372    start_time: Instant,
373    reload_token: CancellationToken,
374) -> HTTPResult {
375    let request_id = guard.request_id();
376    let (inferred_content_type, http_meta, body) = pipeline_result;
377    let status = match (http_meta.status, &body) {
378        (Some(s), _) => s,
379        (None, ResponseTransport::Empty) => 204,
380        (None, _) => 200,
381    };
382    let mut builder = hyper::Response::builder().status(status);
383    let mut header_map = hyper::header::HeaderMap::new();
384
385    // Content-type precedence:
386    // 1. Explicit in http.response headers
387    // 2. Pipeline metadata (from `to json`, etc.)
388    // 3. Inferred: record->json, binary->octet-stream, list/stream of records->ndjson, empty->None
389    // 4. Default: text/html
390    let content_type = http_meta
391        .headers
392        .get("content-type")
393        .or(http_meta.headers.get("Content-Type"))
394        .and_then(|hv| match hv {
395            crate::response::HeaderValue::Single(s) => Some(s.clone()),
396            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
397        })
398        .or(inferred_content_type)
399        .or_else(|| {
400            if matches!(body, ResponseTransport::Empty) {
401                None
402            } else {
403                Some("text/html; charset=utf-8".to_string())
404            }
405        });
406
407    if let Some(ref ct) = content_type {
408        header_map.insert(
409            hyper::header::CONTENT_TYPE,
410            hyper::header::HeaderValue::from_str(ct)?,
411        );
412    }
413
414    // Add compression headers if using brotli
415    if use_brotli {
416        header_map.insert(
417            hyper::header::CONTENT_ENCODING,
418            hyper::header::HeaderValue::from_static("br"),
419        );
420        header_map.insert(
421            hyper::header::VARY,
422            hyper::header::HeaderValue::from_static("accept-encoding"),
423        );
424    }
425
426    // Add SSE-required headers for event streams
427    let is_sse = content_type.as_deref() == Some("text/event-stream");
428    if is_sse {
429        header_map.insert(
430            hyper::header::CACHE_CONTROL,
431            hyper::header::HeaderValue::from_static("no-cache"),
432        );
433        header_map.insert(
434            hyper::header::CONNECTION,
435            hyper::header::HeaderValue::from_static("keep-alive"),
436        );
437    }
438
439    for (k, v) in &http_meta.headers {
440        if k.to_lowercase() != "content-type" {
441            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
442
443            match v {
444                crate::response::HeaderValue::Single(s) => {
445                    let header_value = hyper::header::HeaderValue::from_str(s)?;
446                    header_map.insert(header_name, header_value);
447                }
448                crate::response::HeaderValue::Multiple(values) => {
449                    for value in values {
450                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
451                            header_map.append(header_name.clone(), header_value);
452                        }
453                    }
454                }
455            }
456        }
457    }
458
459    log_response(request_id, status, &header_map, start_time);
460    *builder.headers_mut().unwrap() = header_map;
461
462    let inner_body = match body {
463        ResponseTransport::Empty => Empty::<Bytes>::new()
464            .map_err(|never| match never {})
465            .boxed(),
466        ResponseTransport::Full(bytes) => {
467            if use_brotli {
468                let compressed = compression::compress_full(&bytes)?;
469                Full::new(Bytes::from(compressed))
470                    .map_err(|never| match never {})
471                    .boxed()
472            } else {
473                Full::new(bytes.into())
474                    .map_err(|never| match never {})
475                    .boxed()
476            }
477        }
478        ResponseTransport::Stream(rx) => {
479            // Layer 1: base byte stream, with reload-abort for SSE
480            let byte_stream: Pin<Box<dyn Stream<Item = Vec<u8>> + Send + Sync>> = if is_sse {
481                // SSE streams abort on reload (error triggers client retry)
482                Box::pin(futures_util::stream::unfold(
483                    (ReceiverStream::new(rx), reload_token),
484                    |(mut data_rx, token)| async move {
485                        tokio::select! {
486                            biased;
487                            _ = token.cancelled() => None,
488                            item = StreamExt::next(&mut data_rx) => {
489                                item.map(|data| (data, (data_rx, token)))
490                            }
491                        }
492                    },
493                ))
494            } else {
495                Box::pin(ReceiverStream::new(rx))
496            };
497
498            // Layer 2: optionally compress, then frame
499            if use_brotli {
500                let brotli = compression::BrotliStream::new(byte_stream);
501                BodyExt::boxed(StreamBody::new(brotli))
502            } else {
503                let stream = byte_stream.map(|data| Ok(Frame::data(Bytes::from(data))));
504                BodyExt::boxed(StreamBody::new(stream))
505            }
506        }
507    };
508
509    // Wrap with LoggingBody for phase 3 (complete) logging
510    let logging_body = LoggingBody::new(inner_body, guard);
511    Ok(builder.body(logging_body.boxed())?)
512}