http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Instant;
4
5use arc_swap::ArcSwap;
6use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
7use hyper::body::{Bytes, Frame};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::StreamExt;
10use tower::Service;
11use tower_http::services::{ServeDir, ServeFile};
12
13use crate::compression;
14use crate::logging::{log_request, log_response, LoggingBody};
15use crate::request::{resolve_trusted_ip, Request};
16use crate::response::{Response, ResponseBodyType, ResponseTransport};
17use crate::worker::spawn_eval_thread;
18
19type BoxError = Box<dyn std::error::Error + Send + Sync>;
20type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
21
22pub async fn handle<B>(
23    engine: Arc<ArcSwap<crate::Engine>>,
24    addr: Option<SocketAddr>,
25    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
26    req: hyper::Request<B>,
27) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
28where
29    B: hyper::body::Body + Unpin + Send + 'static,
30    B::Data: Into<Bytes> + Clone + Send,
31    B::Error: Into<BoxError> + Send,
32{
33    // Load current engine snapshot - lock-free atomic operation
34    let engine = engine.load_full();
35    match handle_inner(engine, addr, trusted_proxies, req).await {
36        Ok(response) => Ok(response),
37        Err(err) => {
38            eprintln!("Error handling request: {err}");
39            let response = hyper::Response::builder().status(500).body(
40                Full::new(format!("Script error: {err}").into())
41                    .map_err(|never| match never {})
42                    .boxed(),
43            )?;
44            Ok(response)
45        }
46    }
47}
48
49async fn handle_inner<B>(
50    engine: Arc<crate::Engine>,
51    addr: Option<SocketAddr>,
52    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
53    req: hyper::Request<B>,
54) -> HTTPResult
55where
56    B: hyper::body::Body + Unpin + Send + 'static,
57    B::Data: Into<Bytes> + Clone + Send,
58    B::Error: Into<BoxError> + Send,
59{
60    let (parts, mut body) = req.into_parts();
61
62    // Create channels for request body streaming
63    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
64
65    // Spawn task to read request body frames
66    tokio::task::spawn(async move {
67        while let Some(frame) = body.frame().await {
68            match frame {
69                Ok(frame) => {
70                    if let Some(data) = frame.data_ref() {
71                        let bytes: Bytes = (*data).clone().into();
72                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
73                            break;
74                        }
75                    }
76                }
77                Err(err) => {
78                    let _ = body_tx.send(Err(err.into())).await;
79                    break;
80                }
81            }
82        }
83    });
84
85    // Create ByteStream for Nu pipeline
86    let stream = nu_protocol::ByteStream::from_fn(
87        nu_protocol::Span::unknown(),
88        engine.state.signals().clone(),
89        nu_protocol::ByteStreamType::Unknown,
90        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
91            Some(Ok(bytes)) => {
92                buffer.extend_from_slice(&bytes);
93                Ok(true)
94            }
95            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
96                error: "Body read error".into(),
97                msg: err.to_string(),
98                span: None,
99                help: None,
100                inner: vec![],
101            }),
102            None => Ok(false),
103        },
104    );
105
106    // Generate request ID and capture start time for 3-phase logging
107    let start_time = Instant::now();
108    let request_id = scru128::new();
109
110    let remote_ip = addr.as_ref().map(|a| a.ip());
111    let trusted_ip = resolve_trusted_ip(&parts.headers, remote_ip, &trusted_proxies);
112
113    let request = Request {
114        proto: format!("{:?}", parts.version),
115        method: parts.method.clone(),
116        authority: parts.uri.authority().map(|a| a.to_string()),
117        remote_ip,
118        remote_port: addr.as_ref().map(|a| a.port()),
119        trusted_ip,
120        headers: parts.headers.clone(),
121        uri: parts.uri.clone(),
122        path: parts.uri.path().to_string(),
123        query: parts
124            .uri
125            .query()
126            .map(|v| {
127                url::form_urlencoded::parse(v.as_bytes())
128                    .into_owned()
129                    .collect()
130            })
131            .unwrap_or_else(std::collections::HashMap::new),
132    };
133
134    // Phase 1: Log request
135    log_request(request_id, &request);
136
137    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
138
139    // Wait for both:
140    // 1. Metadata - either from .response or default values when closure skips .response
141    // 2. Body pipeline to start (but not necessarily complete as it may stream)
142    let (meta, body_result): (
143        Response,
144        Result<(Option<String>, ResponseTransport), BoxError>,
145    ) = tokio::join!(
146        async {
147            meta_rx.await.unwrap_or(Response {
148                status: 200,
149                headers: std::collections::HashMap::new(),
150                body_type: ResponseBodyType::Normal,
151            })
152        },
153        async { bridged_body.await.map_err(|e| e.into()) }
154    );
155
156    let use_brotli = compression::accepts_brotli(&parts.headers);
157
158    match &meta.body_type {
159        ResponseBodyType::Normal => {
160            build_normal_response(&meta, Ok(body_result?), use_brotli, request_id, start_time).await
161        }
162        ResponseBodyType::Static {
163            root,
164            path,
165            fallback,
166        } => {
167            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
168            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
169            *static_req.method_mut() = parts.method.clone();
170            *static_req.headers_mut() = parts.headers.clone();
171
172            let res = if let Some(fallback) = fallback {
173                let fp = root.join(fallback);
174                ServeDir::new(root)
175                    .fallback(ServeFile::new(fp))
176                    .call(static_req)
177                    .await?
178            } else {
179                ServeDir::new(root).call(static_req).await?
180            };
181            let (res_parts, body) = res.into_parts();
182            log_response(
183                request_id,
184                res_parts.status.as_u16(),
185                &res_parts.headers,
186                start_time,
187            );
188
189            let bytes = body.collect().await?.to_bytes();
190            let inner_body = Full::new(bytes).map_err(|e| match e {}).boxed();
191            let logging_body = LoggingBody::new(inner_body, request_id);
192            let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
193            Ok(res)
194        }
195        ResponseBodyType::ReverseProxy {
196            target_url,
197            headers,
198            preserve_host,
199            strip_prefix,
200            request_body,
201            query,
202        } => {
203            let body = Full::new(Bytes::from(request_body.clone()));
204            let mut proxy_req = hyper::Request::new(body);
205
206            // Handle strip_prefix
207            let path = if let Some(prefix) = strip_prefix {
208                parts
209                    .uri
210                    .path()
211                    .strip_prefix(prefix)
212                    .unwrap_or(parts.uri.path())
213            } else {
214                parts.uri.path()
215            };
216
217            // Build target URI
218            let target_uri = {
219                let query_string = if let Some(custom_query) = query {
220                    // Use custom query - convert HashMap to query string
221                    url::form_urlencoded::Serializer::new(String::new())
222                        .extend_pairs(custom_query.iter())
223                        .finish()
224                } else if let Some(orig_query) = parts.uri.query() {
225                    // Use original query string
226                    orig_query.to_string()
227                } else {
228                    String::new()
229                };
230
231                if query_string.is_empty() {
232                    format!("{target_url}{path}")
233                } else {
234                    format!("{target_url}{path}?{query_string}")
235                }
236            };
237
238            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
239            *proxy_req.method_mut() = parts.method.clone();
240
241            // Copy original headers
242            let mut header_map = parts.headers.clone();
243
244            // Update Content-Length to match the new body
245            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
246                header_map.insert(
247                    hyper::header::CONTENT_LENGTH,
248                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
249                );
250            }
251
252            // Add custom headers
253            for (k, v) in headers {
254                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
255
256                match v {
257                    crate::response::HeaderValue::Single(s) => {
258                        let header_value = hyper::header::HeaderValue::from_str(s)?;
259                        header_map.insert(header_name, header_value);
260                    }
261                    crate::response::HeaderValue::Multiple(values) => {
262                        for value in values {
263                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
264                                header_map.append(header_name.clone(), header_value);
265                            }
266                        }
267                    }
268                }
269            }
270
271            // Handle preserve_host
272            if !preserve_host {
273                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
274                    if let Some(authority) = target_uri.authority() {
275                        header_map.insert(
276                            hyper::header::HOST,
277                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
278                        );
279                    }
280                }
281            }
282
283            *proxy_req.headers_mut() = header_map;
284
285            // Create a simple HTTP client and forward the request
286            let client =
287                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
288                    .build_http();
289
290            match client.request(proxy_req).await {
291                Ok(response) => {
292                    let (res_parts, body) = response.into_parts();
293                    log_response(
294                        request_id,
295                        res_parts.status.as_u16(),
296                        &res_parts.headers,
297                        start_time,
298                    );
299
300                    let inner_body = body.map_err(|e| e.into()).boxed();
301                    let logging_body = LoggingBody::new(inner_body, request_id);
302                    let res = hyper::Response::from_parts(res_parts, logging_body.boxed());
303                    Ok(res)
304                }
305                Err(_e) => {
306                    let empty_headers = hyper::header::HeaderMap::new();
307                    log_response(request_id, 502, &empty_headers, start_time);
308
309                    let inner_body = Full::new("Bad Gateway".into())
310                        .map_err(|never| match never {})
311                        .boxed();
312                    let logging_body = LoggingBody::new(inner_body, request_id);
313                    let response = hyper::Response::builder()
314                        .status(502)
315                        .body(logging_body.boxed())?;
316                    Ok(response)
317                }
318            }
319        }
320    }
321}
322
323async fn build_normal_response(
324    meta: &Response,
325    body_result: Result<(Option<String>, ResponseTransport), BoxError>,
326    use_brotli: bool,
327    request_id: scru128::Scru128Id,
328    start_time: Instant,
329) -> HTTPResult {
330    let (inferred_content_type, body) = body_result?;
331    let mut builder = hyper::Response::builder().status(meta.status);
332    let mut header_map = hyper::header::HeaderMap::new();
333
334    let content_type = meta
335        .headers
336        .get("content-type")
337        .or(meta.headers.get("Content-Type"))
338        .and_then(|hv| match hv {
339            crate::response::HeaderValue::Single(s) => Some(s.clone()),
340            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
341        })
342        .or(inferred_content_type)
343        .unwrap_or("text/html; charset=utf-8".to_string());
344
345    header_map.insert(
346        hyper::header::CONTENT_TYPE,
347        hyper::header::HeaderValue::from_str(&content_type)?,
348    );
349
350    // Add compression headers if using brotli
351    if use_brotli {
352        header_map.insert(
353            hyper::header::CONTENT_ENCODING,
354            hyper::header::HeaderValue::from_static("br"),
355        );
356        header_map.insert(
357            hyper::header::VARY,
358            hyper::header::HeaderValue::from_static("accept-encoding"),
359        );
360    }
361
362    // Add SSE-required headers for event streams
363    if content_type == "text/event-stream" {
364        header_map.insert(
365            hyper::header::CACHE_CONTROL,
366            hyper::header::HeaderValue::from_static("no-cache"),
367        );
368        header_map.insert(
369            hyper::header::CONNECTION,
370            hyper::header::HeaderValue::from_static("keep-alive"),
371        );
372    }
373
374    for (k, v) in &meta.headers {
375        if k.to_lowercase() != "content-type" {
376            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
377
378            match v {
379                crate::response::HeaderValue::Single(s) => {
380                    let header_value = hyper::header::HeaderValue::from_str(s)?;
381                    header_map.insert(header_name, header_value);
382                }
383                crate::response::HeaderValue::Multiple(values) => {
384                    for value in values {
385                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
386                            header_map.append(header_name.clone(), header_value);
387                        }
388                    }
389                }
390            }
391        }
392    }
393
394    log_response(request_id, meta.status, &header_map, start_time);
395    *builder.headers_mut().unwrap() = header_map;
396
397    let inner_body = match body {
398        ResponseTransport::Empty => Empty::<Bytes>::new()
399            .map_err(|never| match never {})
400            .boxed(),
401        ResponseTransport::Full(bytes) => {
402            if use_brotli {
403                let compressed = compression::compress_full(&bytes)?;
404                Full::new(Bytes::from(compressed))
405                    .map_err(|never| match never {})
406                    .boxed()
407            } else {
408                Full::new(bytes.into())
409                    .map_err(|never| match never {})
410                    .boxed()
411            }
412        }
413        ResponseTransport::Stream(rx) => {
414            if use_brotli {
415                compression::compress_stream(rx)
416            } else {
417                let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
418                StreamBody::new(stream).boxed()
419            }
420        }
421    };
422
423    // Wrap with LoggingBody for phase 3 (complete) logging
424    let logging_body = LoggingBody::new(inner_body, request_id);
425    Ok(builder.body(logging_body.boxed())?)
426}