http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
7use hyper::body::{Bytes, Frame};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::StreamExt;
10use tower::Service;
11use tower_http::services::{ServeDir, ServeFile};
12
13use crate::compression;
14use crate::logging::{log_request, log_response, LoggingBody, RequestGuard};
15use crate::request::{resolve_trusted_ip, Request};
16use crate::response::{Response, ResponseBodyType, ResponseTransport};
17use crate::worker::spawn_eval_thread;
18
19type BoxError = Box<dyn std::error::Error + Send + Sync>;
20type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
21
22pub async fn handle<B>(
23    engine: Arc<ArcSwap<crate::Engine>>,
24    addr: Option<SocketAddr>,
25    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
26    req: hyper::Request<B>,
27) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
28where
29    B: hyper::body::Body + Unpin + Send + 'static,
30    B::Data: Into<Bytes> + Clone + Send,
31    B::Error: Into<BoxError> + Send,
32{
33    // Load current engine snapshot - lock-free atomic operation
34    let engine = engine.load_full();
35    match handle_inner(engine, addr, trusted_proxies, req).await {
36        Ok(response) => Ok(response),
37        Err(err) => {
38            eprintln!("Error handling request: {err}");
39            let response = hyper::Response::builder().status(500).body(
40                Full::new(format!("Script error: {err}").into())
41                    .map_err(|never| match never {})
42                    .boxed(),
43            )?;
44            Ok(response)
45        }
46    }
47}
48
49async fn handle_inner<B>(
50    engine: Arc<crate::Engine>,
51    addr: Option<SocketAddr>,
52    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
53    req: hyper::Request<B>,
54) -> HTTPResult
55where
56    B: hyper::body::Body + Unpin + Send + 'static,
57    B::Data: Into<Bytes> + Clone + Send,
58    B::Error: Into<BoxError> + Send,
59{
60    let (parts, mut body) = req.into_parts();
61
62    // Create channels for request body streaming
63    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
64
65    // Spawn task to read request body frames
66    tokio::task::spawn(async move {
67        while let Some(frame) = body.frame().await {
68            match frame {
69                Ok(frame) => {
70                    if let Some(data) = frame.data_ref() {
71                        let bytes: Bytes = (*data).clone().into();
72                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
73                            break;
74                        }
75                    }
76                }
77                Err(err) => {
78                    let _ = body_tx.send(Err(err.into())).await;
79                    break;
80                }
81            }
82        }
83    });
84
85    // Create ByteStream for Nu pipeline
86    let stream = nu_protocol::ByteStream::from_fn(
87        nu_protocol::Span::unknown(),
88        engine.state.signals().clone(),
89        nu_protocol::ByteStreamType::Unknown,
90        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
91            Some(Ok(bytes)) => {
92                buffer.extend_from_slice(&bytes);
93                Ok(true)
94            }
95            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
96                error: "Body read error".into(),
97                msg: err.to_string(),
98                span: None,
99                help: None,
100                inner: vec![],
101            }),
102            None => Ok(false),
103        },
104    );
105
106    // Generate request ID and guard for logging
107    let start_time = Instant::now();
108    let request_id = scru128::new();
109    let guard = RequestGuard::new(request_id);
110
111    let remote_ip = addr.as_ref().map(|a| a.ip());
112    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
113
114    let request = Request {
115        proto: format!("{:?}", parts.version),
116        method: parts.method.clone(),
117        authority: parts.uri.authority().map(|a| a.to_string()),
118        remote_ip,
119        remote_port: addr.as_ref().map(|a| a.port()),
120        trusted_ip,
121        headers: parts.headers.clone(),
122        uri: parts.uri.clone(),
123        path: parts.uri.path().to_string(),
124        query: parts
125            .uri
126            .query()
127            .map(|v| {
128                url::form_urlencoded::parse(v.as_bytes())
129                    .into_owned()
130                    .collect()
131            })
132            .unwrap_or_else(std::collections::HashMap::new),
133    };
134
135    // Phase 1: Log request
136    log_request(request_id, &request);
137
138    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
139
140    // Wait for both:
141    // 1. Metadata - either from .response or default values when closure skips .response
142    // 2. Body pipeline to start (but not necessarily complete as it may stream)
143    let (meta, body_result): (
144        Response,
145        Result<(Option<String>, ResponseTransport), BoxError>,
146    ) = tokio::join!(
147        async {
148            meta_rx.await.unwrap_or(Response {
149                status: 200,
150                headers: std::collections::HashMap::new(),
151                body_type: ResponseBodyType::Normal,
152            })
153        },
154        async { bridged_body.await.map_err(|e| e.into()) }
155    );
156
157    let use_brotli = compression::accepts_brotli(&parts.headers);
158
159    match &meta.body_type {
160        ResponseBodyType::Normal => {
161            build_normal_response(&meta, Ok(body_result?), use_brotli, guard, start_time).await
162        }
163        ResponseBodyType::Static {
164            root,
165            path,
166            fallback,
167        } => {
168            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
169            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
170            *static_req.method_mut() = parts.method.clone();
171            *static_req.headers_mut() = parts.headers.clone();
172
173            let res = if let Some(fallback) = fallback {
174                let fp = root.join(fallback);
175                ServeDir::new(root)
176                    .fallback(ServeFile::new(fp))
177                    .call(static_req)
178                    .await?
179            } else {
180                ServeDir::new(root).call(static_req).await?
181            };
182            let (res_parts, body) = res.into_parts();
183            log_response(
184                request_id,
185                res_parts.status.as_u16(),
186                &res_parts.headers,
187                start_time,
188            );
189
190            let bytes = body.collect().await?.to_bytes();
191            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
192            let logging_body = LoggingBody::new(inner_body, guard);
193            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
194            Ok(res)
195        }
196        ResponseBodyType::ReverseProxy {
197            target_url,
198            headers,
199            preserve_host,
200            strip_prefix,
201            request_body,
202            query,
203        } => {
204            let body = Full::new(Bytes::from(request_body.clone()));
205            let mut proxy_req = hyper::Request::new(body);
206
207            // Handle strip_prefix
208            let path = if let Some(prefix) = strip_prefix {
209                parts
210                    .uri
211                    .path()
212                    .strip_prefix(prefix)
213                    .unwrap_or(parts.uri.path())
214            } else {
215                parts.uri.path()
216            };
217
218            // Build target URI
219            let target_uri = {
220                let query_string = if let Some(custom_query) = query {
221                    // Use custom query - convert HashMap to query string
222                    url::form_urlencoded::Serializer::new(String::new())
223                        .extend_pairs(custom_query.iter())
224                        .finish()
225                } else if let Some(orig_query) = parts.uri.query() {
226                    // Use original query string
227                    orig_query.to_string()
228                } else {
229                    String::new()
230                };
231
232                if query_string.is_empty() {
233                    format!("{target_url}{path}")
234                } else {
235                    format!("{target_url}{path}?{query_string}")
236                }
237            };
238
239            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
240            *proxy_req.method_mut() = parts.method.clone();
241
242            // Copy original headers
243            let mut header_map = parts.headers.clone();
244
245            // Update Content-Length to match the new body
246            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
247                header_map.insert(
248                    hyper::header::CONTENT_LENGTH,
249                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
250                );
251            }
252
253            // Add custom headers
254            for (k, v) in headers {
255                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
256
257                match v {
258                    crate::response::HeaderValue::Single(s) => {
259                        let header_value = hyper::header::HeaderValue::from_str(s)?;
260                        header_map.insert(header_name, header_value);
261                    }
262                    crate::response::HeaderValue::Multiple(values) => {
263                        for value in values {
264                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
265                                header_map.append(header_name.clone(), header_value);
266                            }
267                        }
268                    }
269                }
270            }
271
272            // Handle preserve_host
273            if !preserve_host {
274                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
275                    if let Some(authority) = target_uri.authority() {
276                        header_map.insert(
277                            hyper::header::HOST,
278                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
279                        );
280                    }
281                }
282            }
283
284            *proxy_req.headers_mut() = header_map;
285
286            // Create a simple HTTP client and forward the request
287            let client =
288                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
289                    .build_http();
290
291            match client.request(proxy_req).await {
292                Ok(response) => {
293                    let (res_parts, body) = response.into_parts();
294                    log_response(
295                        request_id,
296                        res_parts.status.as_u16(),
297                        &res_parts.headers,
298                        start_time,
299                    );
300
301                    let inner_body = body.map_err(|e| e.into()).boxed();
302                    let logging_body = LoggingBody::new(inner_body, guard);
303                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
304                    Ok(res)
305                }
306                Err(_e) => {
307                    let empty_headers = hyper::header::HeaderMap::new();
308                    log_response(request_id, 502, &empty_headers, start_time);
309
310                    let inner_body = Full::new("Bad Gateway".into())
311                        .map_err(|never| match never {})
312                        .boxed();
313                    let logging_body = LoggingBody::new(inner_body, guard);
314                    let response = hyper::Response::builder()
315                        .status(502)
316                        .body(logging_body.boxed())?;
317                    Ok(response)
318                }
319            }
320        }
321    }
322}
323
324async fn build_normal_response(
325    meta: &Response,
326    body_result: Result<(Option<String>, ResponseTransport), BoxError>,
327    use_brotli: bool,
328    guard: RequestGuard,
329    start_time: Instant,
330) -> HTTPResult {
331    let request_id = guard.request_id();
332    let (inferred_content_type, body) = body_result?;
333    let mut builder = hyper::Response::builder().status(meta.status);
334    let mut header_map = hyper::header::HeaderMap::new();
335
336    let content_type = meta
337        .headers
338        .get("content-type")
339        .or(meta.headers.get("Content-Type"))
340        .and_then(|hv| match hv {
341            crate::response::HeaderValue::Single(s) => Some(s.clone()),
342            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
343        })
344        .or(inferred_content_type)
345        .unwrap_or("text/html; charset=utf-8".to_string());
346
347    header_map.insert(
348        hyper::header::CONTENT_TYPE,
349        hyper::header::HeaderValue::from_str(&content_type)?,
350    );
351
352    // Add compression headers if using brotli
353    if use_brotli {
354        header_map.insert(
355            hyper::header::CONTENT_ENCODING,
356            hyper::header::HeaderValue::from_static("br"),
357        );
358        header_map.insert(
359            hyper::header::VARY,
360            hyper::header::HeaderValue::from_static("accept-encoding"),
361        );
362    }
363
364    // Add SSE-required headers for event streams
365    if content_type == "text/event-stream" {
366        header_map.insert(
367            hyper::header::CACHE_CONTROL,
368            hyper::header::HeaderValue::from_static("no-cache"),
369        );
370        header_map.insert(
371            hyper::header::CONNECTION,
372            hyper::header::HeaderValue::from_static("keep-alive"),
373        );
374    }
375
376    for (k, v) in &meta.headers {
377        if k.to_lowercase() != "content-type" {
378            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
379
380            match v {
381                crate::response::HeaderValue::Single(s) => {
382                    let header_value = hyper::header::HeaderValue::from_str(s)?;
383                    header_map.insert(header_name, header_value);
384                }
385                crate::response::HeaderValue::Multiple(values) => {
386                    for value in values {
387                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
388                            header_map.append(header_name.clone(), header_value);
389                        }
390                    }
391                }
392            }
393        }
394    }
395
396    log_response(request_id, meta.status, &header_map, start_time);
397    *builder.headers_mut().unwrap() = header_map;
398
399    let inner_body = match body {
400        ResponseTransport::Empty => Empty::<Bytes>::new()
401            .map_err(|never| match never {})
402            .boxed(),
403        ResponseTransport::Full(bytes) => {
404            if use_brotli {
405                let compressed = compression::compress_full(&bytes)?;
406                Full::new(Bytes::from(compressed))
407                    .map_err(|never| match never {})
408                    .boxed()
409            } else {
410                Full::new(bytes.into())
411                    .map_err(|never| match never {})
412                    .boxed()
413            }
414        }
415        ResponseTransport::Stream(rx) => {
416            if use_brotli {
417                compression::compress_stream(rx)
418            } else {
419                let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
420                StreamBody::new(stream).boxed()
421            }
422        }
423    };
424
425    // Wrap with LoggingBody for phase 3 (complete) logging
426    let logging_body = LoggingBody::new(inner_body, guard);
427    Ok(builder.body(logging_body.boxed())?)
428}