Skip to main content

http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use futures_util::StreamExt;
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
8use hyper::body::{Bytes, Frame};
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12use tower_http::services::{ServeDir, ServeFile};
13
14use crate::compression;
15use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
16use crate::request::{resolve_trusted_ip, Request};
17use crate::response::{Response, ResponseBodyType, ResponseTransport};
18use crate::worker::{spawn_eval_thread, PipelineResult};
19
/// Boxed, thread-safe error type used for every fallible step in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Result of handling a request: a hyper response with a boxed body, or a boxed error.
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
22
/// Request path at which the embedded Datastar bundle is served (only when `AppConfig::datastar` is set).
const DATASTAR_JS_PATH: &str = "/datastar@1.0.0-RC.7.js";
/// Datastar bundle embedded at compile time; sent when the client does not accept brotli.
const DATASTAR_JS: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.7.js");
/// Second embedded copy of the bundle, sent with `Content-Encoding: br` when the client accepts brotli.
const DATASTAR_JS_BROTLI: &[u8] = include_bytes!("stdlib/datastar/datastar@1.0.0-RC.7.js.br");
26
/// Server-wide settings shared (behind an `Arc`) by every request handler.
pub struct AppConfig {
    /// Networks passed to `resolve_trusted_ip` when computing a request's `trusted_ip`.
    pub trusted_proxies: Vec<ipnet::IpNet>,
    /// Enables the built-in route that serves the embedded Datastar JS bundle.
    pub datastar: bool,
    /// NOTE(review): not read anywhere in this file; presumably a development-mode
    /// flag consumed elsewhere — confirm against callers.
    pub dev: bool,
}
32
33pub async fn handle<B>(
34    engine: Arc<ArcSwap<crate::Engine>>,
35    addr: Option<SocketAddr>,
36    config: Arc<AppConfig>,
37    req: hyper::Request<B>,
38) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
39where
40    B: hyper::body::Body + Unpin + Send + 'static,
41    B::Data: Into<Bytes> + Clone + Send,
42    B::Error: Into<BoxError> + Send,
43{
44    // Load current engine snapshot - lock-free atomic operation
45    let engine = engine.load_full();
46    match handle_inner(engine, addr, config, req).await {
47        Ok(response) => Ok(response),
48        Err(err) => {
49            eprintln!("Error handling request: {err}");
50            let response = hyper::Response::builder().status(500).body(
51                Full::new(format!("Script error: {err}").into())
52                    .map_err(|never| match never {})
53                    .boxed(),
54            )?;
55            Ok(response)
56        }
57    }
58}
59
60async fn handle_inner<B>(
61    engine: Arc<crate::Engine>,
62    addr: Option<SocketAddr>,
63    config: Arc<AppConfig>,
64    req: hyper::Request<B>,
65) -> HTTPResult
66where
67    B: hyper::body::Body + Unpin + Send + 'static,
68    B::Data: Into<Bytes> + Clone + Send,
69    B::Error: Into<BoxError> + Send,
70{
71    let (parts, mut body) = req.into_parts();
72
73    // Create channels for request body streaming
74    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
75
76    // Spawn task to read request body frames
77    tokio::task::spawn(async move {
78        while let Some(frame) = body.frame().await {
79            match frame {
80                Ok(frame) => {
81                    if let Some(data) = frame.data_ref() {
82                        let bytes: Bytes = (*data).clone().into();
83                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
84                            break;
85                        }
86                    }
87                }
88                Err(err) => {
89                    let _ = body_tx.send(Err(err.into())).await;
90                    break;
91                }
92            }
93        }
94    });
95
96    // Create ByteStream for Nu pipeline
97    let stream = nu_protocol::ByteStream::from_fn(
98        nu_protocol::Span::unknown(),
99        engine.state.signals().clone(),
100        nu_protocol::ByteStreamType::Unknown,
101        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
102            Some(Ok(bytes)) => {
103                buffer.extend_from_slice(&bytes);
104                Ok(true)
105            }
106            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
107                error: "Body read error".into(),
108                msg: err.to_string(),
109                span: None,
110                help: None,
111                inner: vec![],
112            }),
113            None => Ok(false),
114        },
115    );
116
117    // Generate request ID and guard for logging
118    let start_time = Instant::now();
119    let request_id = scru128::new();
120    let guard = RequestGuard::new(request_id);
121
122    let remote_ip = addr.as_ref().map(|a| a.ip());
123    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &config.trusted_proxies);
124
125    let request = Request {
126        proto: format!("{:?}", parts.version),
127        method: parts.method.clone(),
128        authority: parts.uri.authority().map(|a| a.to_string()),
129        remote_ip,
130        remote_port: addr.as_ref().map(|a| a.port()),
131        trusted_ip,
132        headers: parts.headers.clone(),
133        uri: parts.uri.clone(),
134        path: parts.uri.path().to_string(),
135        query: parts
136            .uri
137            .query()
138            .map(|v| {
139                url::form_urlencoded::parse(v.as_bytes())
140                    .into_owned()
141                    .collect()
142            })
143            .unwrap_or_else(std::collections::HashMap::new),
144    };
145
146    // Phase 1: Log request
147    log_request(request_id, &request);
148
149    // Built-in route: serve embedded Datastar JS bundle (requires --datastar flag)
150    if config.datastar && request.path == DATASTAR_JS_PATH {
151        let use_brotli = compression::accepts_brotli(&parts.headers);
152        let mut header_map = hyper::header::HeaderMap::new();
153        header_map.insert(
154            hyper::header::CONTENT_TYPE,
155            hyper::header::HeaderValue::from_static("application/javascript"),
156        );
157        header_map.insert(
158            hyper::header::CACHE_CONTROL,
159            hyper::header::HeaderValue::from_static("public, max-age=31536000, immutable"),
160        );
161        let body = if use_brotli {
162            header_map.insert(
163                hyper::header::CONTENT_ENCODING,
164                hyper::header::HeaderValue::from_static("br"),
165            );
166            header_map.insert(
167                hyper::header::VARY,
168                hyper::header::HeaderValue::from_static("accept-encoding"),
169            );
170            Full::new(Bytes::from_static(DATASTAR_JS_BROTLI))
171                .map_err(|never| match never {})
172                .boxed()
173        } else {
174            Full::new(Bytes::from_static(DATASTAR_JS))
175                .map_err(|never| match never {})
176                .boxed()
177        };
178        log_response(request_id, 200, &header_map, start_time);
179        let logging_body = LoggingBody::new(body, guard);
180        let mut response = hyper::Response::builder()
181            .status(200)
182            .body(logging_body.boxed())?;
183        *response.headers_mut() = header_map;
184        return Ok(response);
185    }
186
187    let reload_token = engine.reload_token.clone();
188    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
189
190    // Wait for both:
191    // 1. Special response (from .static or .reverse-proxy) - None if normal response
192    // 2. Body pipeline result (includes http.response metadata for normal responses)
193    let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
194        tokio::join!(async { meta_rx.await.ok() }, async {
195            bridged_body.await.map_err(|e| e.into())
196        });
197
198    let use_brotli = compression::accepts_brotli(&parts.headers);
199
200    // Check if we got a special response (.static or .reverse-proxy)
201    match special_response.as_ref().map(|r| &r.body_type) {
202        Some(ResponseBodyType::Normal) | None => {
203            // Normal response - use metadata from pipeline
204            build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
205        }
206        Some(ResponseBodyType::Static {
207            root,
208            path,
209            fallback,
210        }) => {
211            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
212            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
213            *static_req.method_mut() = parts.method.clone();
214            *static_req.headers_mut() = parts.headers.clone();
215
216            let res = if let Some(fallback) = fallback {
217                let fp = root.join(fallback);
218                ServeDir::new(root)
219                    .fallback(ServeFile::new(fp))
220                    .call(static_req)
221                    .await?
222            } else {
223                ServeDir::new(root).call(static_req).await?
224            };
225            let (res_parts, body) = res.into_parts();
226            log_response(
227                request_id,
228                res_parts.status.as_u16(),
229                &res_parts.headers,
230                start_time,
231            );
232
233            let bytes = body.collect().await?.to_bytes();
234            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
235            let logging_body = LoggingBody::new(inner_body, guard);
236            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
237            Ok(res)
238        }
239        Some(ResponseBodyType::ReverseProxy {
240            target_url,
241            headers,
242            preserve_host,
243            strip_prefix,
244            request_body,
245            query,
246        }) => {
247            let body = Full::new(Bytes::from(request_body.clone()));
248            let mut proxy_req = hyper::Request::new(body);
249
250            // Handle strip_prefix
251            let path = if let Some(prefix) = strip_prefix {
252                parts
253                    .uri
254                    .path()
255                    .strip_prefix(prefix)
256                    .unwrap_or(parts.uri.path())
257            } else {
258                parts.uri.path()
259            };
260
261            // Build target URI
262            let target_uri = {
263                let query_string = if let Some(custom_query) = query {
264                    // Use custom query - convert HashMap to query string
265                    url::form_urlencoded::Serializer::new(String::new())
266                        .extend_pairs(custom_query.iter())
267                        .finish()
268                } else if let Some(orig_query) = parts.uri.query() {
269                    // Use original query string
270                    orig_query.to_string()
271                } else {
272                    String::new()
273                };
274
275                if query_string.is_empty() {
276                    format!("{target_url}{path}")
277                } else {
278                    format!("{target_url}{path}?{query_string}")
279                }
280            };
281
282            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
283            *proxy_req.method_mut() = parts.method.clone();
284
285            // Copy original headers
286            let mut header_map = parts.headers.clone();
287
288            // Update Content-Length to match the new body
289            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
290                header_map.insert(
291                    hyper::header::CONTENT_LENGTH,
292                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
293                );
294            }
295
296            // Add custom headers
297            for (k, v) in headers {
298                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
299
300                match v {
301                    crate::response::HeaderValue::Single(s) => {
302                        let header_value = hyper::header::HeaderValue::from_str(s)?;
303                        header_map.insert(header_name, header_value);
304                    }
305                    crate::response::HeaderValue::Multiple(values) => {
306                        for value in values {
307                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
308                                header_map.append(header_name.clone(), header_value);
309                            }
310                        }
311                    }
312                }
313            }
314
315            // Handle preserve_host
316            if !preserve_host {
317                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
318                    if let Some(authority) = target_uri.authority() {
319                        header_map.insert(
320                            hyper::header::HOST,
321                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
322                        );
323                    }
324                }
325            }
326
327            *proxy_req.headers_mut() = header_map;
328
329            // Create a simple HTTP client and forward the request
330            let client =
331                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
332                    .build_http();
333
334            match client.request(proxy_req).await {
335                Ok(response) => {
336                    let (res_parts, body) = response.into_parts();
337                    log_response(
338                        request_id,
339                        res_parts.status.as_u16(),
340                        &res_parts.headers,
341                        start_time,
342                    );
343
344                    let inner_body = body.map_err(|e| e.into()).boxed();
345                    let logging_body = LoggingBody::new(inner_body, guard);
346                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
347                    Ok(res)
348                }
349                Err(_e) => {
350                    let empty_headers = hyper::header::HeaderMap::new();
351                    log_response(request_id, 502, &empty_headers, start_time);
352
353                    let inner_body = Full::new("Bad Gateway".into())
354                        .map_err(|never| match never {})
355                        .boxed();
356                    let logging_body = LoggingBody::new(inner_body, guard);
357                    let response = hyper::Response::builder()
358                        .status(502)
359                        .body(logging_body.boxed())?;
360                    Ok(response)
361                }
362            }
363        }
364    }
365}
366
367async fn build_normal_response(
368    pipeline_result: PipelineResult,
369    use_brotli: bool,
370    guard: RequestGuard,
371    start_time: Instant,
372    reload_token: CancellationToken,
373) -> HTTPResult {
374    let request_id = guard.request_id();
375    let (inferred_content_type, http_meta, body) = pipeline_result;
376    let status = match (http_meta.status, &body) {
377        (Some(s), _) => s,
378        (None, ResponseTransport::Empty) => 204,
379        (None, _) => 200,
380    };
381    let mut builder = hyper::Response::builder().status(status);
382    let mut header_map = hyper::header::HeaderMap::new();
383
384    // Content-type precedence:
385    // 1. Explicit in http.response headers
386    // 2. Pipeline metadata (from `to json`, etc.)
387    // 3. Inferred: record->json, binary->octet-stream, list/stream of records->ndjson, empty->None
388    // 4. Default: text/html
389    let content_type = http_meta
390        .headers
391        .get("content-type")
392        .or(http_meta.headers.get("Content-Type"))
393        .and_then(|hv| match hv {
394            crate::response::HeaderValue::Single(s) => Some(s.clone()),
395            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
396        })
397        .or(inferred_content_type)
398        .or_else(|| {
399            if matches!(body, ResponseTransport::Empty) {
400                None
401            } else {
402                Some("text/html; charset=utf-8".to_string())
403            }
404        });
405
406    if let Some(ref ct) = content_type {
407        header_map.insert(
408            hyper::header::CONTENT_TYPE,
409            hyper::header::HeaderValue::from_str(ct)?,
410        );
411    }
412
413    // Add compression headers if using brotli
414    if use_brotli {
415        header_map.insert(
416            hyper::header::CONTENT_ENCODING,
417            hyper::header::HeaderValue::from_static("br"),
418        );
419        header_map.insert(
420            hyper::header::VARY,
421            hyper::header::HeaderValue::from_static("accept-encoding"),
422        );
423    }
424
425    // Add SSE-required headers for event streams
426    let is_sse = content_type.as_deref() == Some("text/event-stream");
427    if is_sse {
428        header_map.insert(
429            hyper::header::CACHE_CONTROL,
430            hyper::header::HeaderValue::from_static("no-cache"),
431        );
432        header_map.insert(
433            hyper::header::CONNECTION,
434            hyper::header::HeaderValue::from_static("keep-alive"),
435        );
436    }
437
438    for (k, v) in &http_meta.headers {
439        if k.to_lowercase() != "content-type" {
440            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
441
442            match v {
443                crate::response::HeaderValue::Single(s) => {
444                    let header_value = hyper::header::HeaderValue::from_str(s)?;
445                    header_map.insert(header_name, header_value);
446                }
447                crate::response::HeaderValue::Multiple(values) => {
448                    for value in values {
449                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
450                            header_map.append(header_name.clone(), header_value);
451                        }
452                    }
453                }
454            }
455        }
456    }
457
458    log_response(request_id, status, &header_map, start_time);
459    *builder.headers_mut().unwrap() = header_map;
460
461    let inner_body = match body {
462        ResponseTransport::Empty => Empty::<Bytes>::new()
463            .map_err(|never| match never {})
464            .boxed(),
465        ResponseTransport::Full(bytes) => {
466            if use_brotli {
467                let compressed = compression::compress_full(&bytes)?;
468                Full::new(Bytes::from(compressed))
469                    .map_err(|never| match never {})
470                    .boxed()
471            } else {
472                Full::new(bytes.into())
473                    .map_err(|never| match never {})
474                    .boxed()
475            }
476        }
477        ResponseTransport::Stream(rx) => {
478            if use_brotli {
479                compression::compress_stream(rx)
480            } else if is_sse {
481                // SSE streams abort on reload (error triggers client retry)
482                let stream = futures_util::stream::try_unfold(
483                    (ReceiverStream::new(rx), reload_token),
484                    |(mut data_rx, token)| async move {
485                        tokio::select! {
486                            biased;
487                            _ = token.cancelled() => {
488                                Err(std::io::Error::other("reload").into())
489                            }
490                            item = StreamExt::next(&mut data_rx) => {
491                                match item {
492                                    Some(data) => Ok(Some((Frame::data(Bytes::from(data)), (data_rx, token)))),
493                                    None => Ok(None),
494                                }
495                            }
496                        }
497                    },
498                );
499                BodyExt::boxed(StreamBody::new(stream))
500            } else {
501                let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
502                BodyExt::boxed(StreamBody::new(stream))
503            }
504        }
505    };
506
507    // Wrap with LoggingBody for phase 3 (complete) logging
508    let logging_body = LoggingBody::new(inner_body, guard);
509    Ok(builder.body(logging_body.boxed())?)
510}