Skip to main content

http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use futures_util::StreamExt;
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
8use hyper::body::{Bytes, Frame};
9use tokio_stream::wrappers::ReceiverStream;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12use tower_http::services::{ServeDir, ServeFile};
13
14use crate::compression;
15use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
16use crate::request::{resolve_trusted_ip, Request};
17use crate::response::{Response, ResponseBodyType, ResponseTransport};
18use crate::worker::{spawn_eval_thread, PipelineResult};
19
20type BoxError = Box<dyn std::error::Error + Send + Sync>;
21type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
22
23pub async fn handle<B>(
24    engine: Arc<ArcSwap<crate::Engine>>,
25    addr: Option<SocketAddr>,
26    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
27    req: hyper::Request<B>,
28) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
29where
30    B: hyper::body::Body + Unpin + Send + 'static,
31    B::Data: Into<Bytes> + Clone + Send,
32    B::Error: Into<BoxError> + Send,
33{
34    // Load current engine snapshot - lock-free atomic operation
35    let engine = engine.load_full();
36    match handle_inner(engine, addr, trusted_proxies, req).await {
37        Ok(response) => Ok(response),
38        Err(err) => {
39            eprintln!("Error handling request: {err}");
40            let response = hyper::Response::builder().status(500).body(
41                Full::new(format!("Script error: {err}").into())
42                    .map_err(|never| match never {})
43                    .boxed(),
44            )?;
45            Ok(response)
46        }
47    }
48}
49
50async fn handle_inner<B>(
51    engine: Arc<crate::Engine>,
52    addr: Option<SocketAddr>,
53    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
54    req: hyper::Request<B>,
55) -> HTTPResult
56where
57    B: hyper::body::Body + Unpin + Send + 'static,
58    B::Data: Into<Bytes> + Clone + Send,
59    B::Error: Into<BoxError> + Send,
60{
61    let (parts, mut body) = req.into_parts();
62
63    // Create channels for request body streaming
64    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
65
66    // Spawn task to read request body frames
67    tokio::task::spawn(async move {
68        while let Some(frame) = body.frame().await {
69            match frame {
70                Ok(frame) => {
71                    if let Some(data) = frame.data_ref() {
72                        let bytes: Bytes = (*data).clone().into();
73                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
74                            break;
75                        }
76                    }
77                }
78                Err(err) => {
79                    let _ = body_tx.send(Err(err.into())).await;
80                    break;
81                }
82            }
83        }
84    });
85
86    // Create ByteStream for Nu pipeline
87    let stream = nu_protocol::ByteStream::from_fn(
88        nu_protocol::Span::unknown(),
89        engine.state.signals().clone(),
90        nu_protocol::ByteStreamType::Unknown,
91        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
92            Some(Ok(bytes)) => {
93                buffer.extend_from_slice(&bytes);
94                Ok(true)
95            }
96            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
97                error: "Body read error".into(),
98                msg: err.to_string(),
99                span: None,
100                help: None,
101                inner: vec![],
102            }),
103            None => Ok(false),
104        },
105    );
106
107    // Generate request ID and guard for logging
108    let start_time = Instant::now();
109    let request_id = scru128::new();
110    let guard = RequestGuard::new(request_id);
111
112    let remote_ip = addr.as_ref().map(|a| a.ip());
113    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
114
115    let request = Request {
116        proto: format!("{:?}", parts.version),
117        method: parts.method.clone(),
118        authority: parts.uri.authority().map(|a| a.to_string()),
119        remote_ip,
120        remote_port: addr.as_ref().map(|a| a.port()),
121        trusted_ip,
122        headers: parts.headers.clone(),
123        uri: parts.uri.clone(),
124        path: parts.uri.path().to_string(),
125        query: parts
126            .uri
127            .query()
128            .map(|v| {
129                url::form_urlencoded::parse(v.as_bytes())
130                    .into_owned()
131                    .collect()
132            })
133            .unwrap_or_else(std::collections::HashMap::new),
134    };
135
136    // Phase 1: Log request
137    log_request(request_id, &request);
138
139    let reload_token = engine.reload_token.clone();
140    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
141
142    // Wait for both:
143    // 1. Special response (from .static or .reverse-proxy) - None if normal response
144    // 2. Body pipeline result (includes http.response metadata for normal responses)
145    let (special_response, body_result): (Option<Response>, Result<PipelineResult, BoxError>) =
146        tokio::join!(async { meta_rx.await.ok() }, async {
147            bridged_body.await.map_err(|e| e.into())
148        });
149
150    let use_brotli = compression::accepts_brotli(&parts.headers);
151
152    // Check if we got a special response (.static or .reverse-proxy)
153    match special_response.as_ref().map(|r| &r.body_type) {
154        Some(ResponseBodyType::Normal) | None => {
155            // Normal response - use metadata from pipeline
156            build_normal_response(body_result?, use_brotli, guard, start_time, reload_token).await
157        }
158        Some(ResponseBodyType::Static {
159            root,
160            path,
161            fallback,
162        }) => {
163            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
164            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
165            *static_req.method_mut() = parts.method.clone();
166            *static_req.headers_mut() = parts.headers.clone();
167
168            let res = if let Some(fallback) = fallback {
169                let fp = root.join(fallback);
170                ServeDir::new(root)
171                    .fallback(ServeFile::new(fp))
172                    .call(static_req)
173                    .await?
174            } else {
175                ServeDir::new(root).call(static_req).await?
176            };
177            let (res_parts, body) = res.into_parts();
178            log_response(
179                request_id,
180                res_parts.status.as_u16(),
181                &res_parts.headers,
182                start_time,
183            );
184
185            let bytes = body.collect().await?.to_bytes();
186            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
187            let logging_body = LoggingBody::new(inner_body, guard);
188            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
189            Ok(res)
190        }
191        Some(ResponseBodyType::ReverseProxy {
192            target_url,
193            headers,
194            preserve_host,
195            strip_prefix,
196            request_body,
197            query,
198        }) => {
199            let body = Full::new(Bytes::from(request_body.clone()));
200            let mut proxy_req = hyper::Request::new(body);
201
202            // Handle strip_prefix
203            let path = if let Some(prefix) = strip_prefix {
204                parts
205                    .uri
206                    .path()
207                    .strip_prefix(prefix)
208                    .unwrap_or(parts.uri.path())
209            } else {
210                parts.uri.path()
211            };
212
213            // Build target URI
214            let target_uri = {
215                let query_string = if let Some(custom_query) = query {
216                    // Use custom query - convert HashMap to query string
217                    url::form_urlencoded::Serializer::new(String::new())
218                        .extend_pairs(custom_query.iter())
219                        .finish()
220                } else if let Some(orig_query) = parts.uri.query() {
221                    // Use original query string
222                    orig_query.to_string()
223                } else {
224                    String::new()
225                };
226
227                if query_string.is_empty() {
228                    format!("{target_url}{path}")
229                } else {
230                    format!("{target_url}{path}?{query_string}")
231                }
232            };
233
234            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
235            *proxy_req.method_mut() = parts.method.clone();
236
237            // Copy original headers
238            let mut header_map = parts.headers.clone();
239
240            // Update Content-Length to match the new body
241            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
242                header_map.insert(
243                    hyper::header::CONTENT_LENGTH,
244                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
245                );
246            }
247
248            // Add custom headers
249            for (k, v) in headers {
250                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
251
252                match v {
253                    crate::response::HeaderValue::Single(s) => {
254                        let header_value = hyper::header::HeaderValue::from_str(s)?;
255                        header_map.insert(header_name, header_value);
256                    }
257                    crate::response::HeaderValue::Multiple(values) => {
258                        for value in values {
259                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
260                                header_map.append(header_name.clone(), header_value);
261                            }
262                        }
263                    }
264                }
265            }
266
267            // Handle preserve_host
268            if !preserve_host {
269                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
270                    if let Some(authority) = target_uri.authority() {
271                        header_map.insert(
272                            hyper::header::HOST,
273                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
274                        );
275                    }
276                }
277            }
278
279            *proxy_req.headers_mut() = header_map;
280
281            // Create a simple HTTP client and forward the request
282            let client =
283                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
284                    .build_http();
285
286            match client.request(proxy_req).await {
287                Ok(response) => {
288                    let (res_parts, body) = response.into_parts();
289                    log_response(
290                        request_id,
291                        res_parts.status.as_u16(),
292                        &res_parts.headers,
293                        start_time,
294                    );
295
296                    let inner_body = body.map_err(|e| e.into()).boxed();
297                    let logging_body = LoggingBody::new(inner_body, guard);
298                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
299                    Ok(res)
300                }
301                Err(_e) => {
302                    let empty_headers = hyper::header::HeaderMap::new();
303                    log_response(request_id, 502, &empty_headers, start_time);
304
305                    let inner_body = Full::new("Bad Gateway".into())
306                        .map_err(|never| match never {})
307                        .boxed();
308                    let logging_body = LoggingBody::new(inner_body, guard);
309                    let response = hyper::Response::builder()
310                        .status(502)
311                        .body(logging_body.boxed())?;
312                    Ok(response)
313                }
314            }
315        }
316    }
317}
318
319async fn build_normal_response(
320    pipeline_result: PipelineResult,
321    use_brotli: bool,
322    guard: RequestGuard,
323    start_time: Instant,
324    reload_token: CancellationToken,
325) -> HTTPResult {
326    let request_id = guard.request_id();
327    let (inferred_content_type, http_meta, body) = pipeline_result;
328    let status = match (http_meta.status, &body) {
329        (Some(s), _) => s,
330        (None, ResponseTransport::Empty) => 204,
331        (None, _) => 200,
332    };
333    let mut builder = hyper::Response::builder().status(status);
334    let mut header_map = hyper::header::HeaderMap::new();
335
336    // Content-type precedence:
337    // 1. Explicit in http.response headers
338    // 2. Pipeline metadata (from `to json`, etc.)
339    // 3. Inferred: record->json, binary->octet-stream, list/stream of records->ndjson, empty->None
340    // 4. Default: text/html
341    let content_type = http_meta
342        .headers
343        .get("content-type")
344        .or(http_meta.headers.get("Content-Type"))
345        .and_then(|hv| match hv {
346            crate::response::HeaderValue::Single(s) => Some(s.clone()),
347            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
348        })
349        .or(inferred_content_type)
350        .or_else(|| {
351            if matches!(body, ResponseTransport::Empty) {
352                None
353            } else {
354                Some("text/html; charset=utf-8".to_string())
355            }
356        });
357
358    if let Some(ref ct) = content_type {
359        header_map.insert(
360            hyper::header::CONTENT_TYPE,
361            hyper::header::HeaderValue::from_str(ct)?,
362        );
363    }
364
365    // Add compression headers if using brotli
366    if use_brotli {
367        header_map.insert(
368            hyper::header::CONTENT_ENCODING,
369            hyper::header::HeaderValue::from_static("br"),
370        );
371        header_map.insert(
372            hyper::header::VARY,
373            hyper::header::HeaderValue::from_static("accept-encoding"),
374        );
375    }
376
377    // Add SSE-required headers for event streams
378    let is_sse = content_type.as_deref() == Some("text/event-stream");
379    if is_sse {
380        header_map.insert(
381            hyper::header::CACHE_CONTROL,
382            hyper::header::HeaderValue::from_static("no-cache"),
383        );
384        header_map.insert(
385            hyper::header::CONNECTION,
386            hyper::header::HeaderValue::from_static("keep-alive"),
387        );
388    }
389
390    for (k, v) in &http_meta.headers {
391        if k.to_lowercase() != "content-type" {
392            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
393
394            match v {
395                crate::response::HeaderValue::Single(s) => {
396                    let header_value = hyper::header::HeaderValue::from_str(s)?;
397                    header_map.insert(header_name, header_value);
398                }
399                crate::response::HeaderValue::Multiple(values) => {
400                    for value in values {
401                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
402                            header_map.append(header_name.clone(), header_value);
403                        }
404                    }
405                }
406            }
407        }
408    }
409
410    log_response(request_id, status, &header_map, start_time);
411    *builder.headers_mut().unwrap() = header_map;
412
413    let inner_body = match body {
414        ResponseTransport::Empty => Empty::<Bytes>::new()
415            .map_err(|never| match never {})
416            .boxed(),
417        ResponseTransport::Full(bytes) => {
418            if use_brotli {
419                let compressed = compression::compress_full(&bytes)?;
420                Full::new(Bytes::from(compressed))
421                    .map_err(|never| match never {})
422                    .boxed()
423            } else {
424                Full::new(bytes.into())
425                    .map_err(|never| match never {})
426                    .boxed()
427            }
428        }
429        ResponseTransport::Stream(rx) => {
430            if use_brotli {
431                compression::compress_stream(rx)
432            } else if is_sse {
433                // SSE streams abort on reload (error triggers client retry)
434                let stream = futures_util::stream::try_unfold(
435                    (ReceiverStream::new(rx), reload_token),
436                    |(mut data_rx, token)| async move {
437                        tokio::select! {
438                            biased;
439                            _ = token.cancelled() => {
440                                Err(std::io::Error::other("reload").into())
441                            }
442                            item = StreamExt::next(&mut data_rx) => {
443                                match item {
444                                    Some(data) => Ok(Some((Frame::data(Bytes::from(data)), (data_rx, token)))),
445                                    None => Ok(None),
446                                }
447                            }
448                        }
449                    },
450                );
451                BodyExt::boxed(StreamBody::new(stream))
452            } else {
453                let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
454                BodyExt::boxed(StreamBody::new(stream))
455            }
456        }
457    };
458
459    // Wrap with LoggingBody for phase 3 (complete) logging
460    let logging_body = LoggingBody::new(inner_body, guard);
461    Ok(builder.body(logging_body.boxed())?)
462}