http_nu/
handler.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
6use hyper::body::{Bytes, Frame};
7use tokio_stream::wrappers::ReceiverStream;
8use tokio_stream::StreamExt;
9use tower::Service;
10use tower_http::services::{ServeDir, ServeFile};
11
12use crate::compression;
13use crate::request::Request;
14use crate::response::{Response, ResponseBodyType, ResponseTransport};
15use crate::worker::spawn_eval_thread;
16
/// Boxed, `Send + Sync` error used as the common error type in this module.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Outcome of building a response: a hyper response with a boxed body, or a [`BoxError`].
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
19
20pub async fn handle<B>(
21    engine: Arc<ArcSwap<crate::Engine>>,
22    addr: Option<SocketAddr>,
23    req: hyper::Request<B>,
24) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
25where
26    B: hyper::body::Body + Unpin + Send + 'static,
27    B::Data: Into<Bytes> + Clone + Send,
28    B::Error: Into<BoxError> + Send,
29{
30    // Load current engine snapshot - lock-free atomic operation
31    let engine = engine.load_full();
32    match handle_inner(engine, addr, req).await {
33        Ok(response) => Ok(response),
34        Err(err) => {
35            eprintln!("Error handling request: {err}");
36            let response = hyper::Response::builder().status(500).body(
37                Full::new(format!("Script error: {err}").into())
38                    .map_err(|never| match never {})
39                    .boxed(),
40            )?;
41            Ok(response)
42        }
43    }
44}
45
46async fn handle_inner<B>(
47    engine: Arc<crate::Engine>,
48    addr: Option<SocketAddr>,
49    req: hyper::Request<B>,
50) -> HTTPResult
51where
52    B: hyper::body::Body + Unpin + Send + 'static,
53    B::Data: Into<Bytes> + Clone + Send,
54    B::Error: Into<BoxError> + Send,
55{
56    let (parts, mut body) = req.into_parts();
57
58    // Create channels for request body streaming
59    let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
60
61    // Spawn task to read request body frames
62    tokio::task::spawn(async move {
63        while let Some(frame) = body.frame().await {
64            match frame {
65                Ok(frame) => {
66                    if let Some(data) = frame.data_ref() {
67                        let bytes: Bytes = (*data).clone().into();
68                        if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
69                            break;
70                        }
71                    }
72                }
73                Err(err) => {
74                    let _ = body_tx.send(Err(err.into())).await;
75                    break;
76                }
77            }
78        }
79    });
80
81    // Create ByteStream for Nu pipeline
82    let stream = nu_protocol::ByteStream::from_fn(
83        nu_protocol::Span::unknown(),
84        engine.state.signals().clone(),
85        nu_protocol::ByteStreamType::Unknown,
86        move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
87            Some(Ok(bytes)) => {
88                buffer.extend_from_slice(&bytes);
89                Ok(true)
90            }
91            Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
92                error: "Body read error".into(),
93                msg: err.to_string(),
94                span: None,
95                help: None,
96                inner: vec![],
97            }),
98            None => Ok(false),
99        },
100    );
101
102    let request = Request {
103        proto: format!("{:?}", parts.version),
104        method: parts.method.clone(),
105        authority: parts.uri.authority().map(|a| a.to_string()),
106        remote_ip: addr.as_ref().map(|a| a.ip()),
107        remote_port: addr.as_ref().map(|a| a.port()),
108        headers: parts.headers.clone(),
109        uri: parts.uri.clone(),
110        path: parts.uri.path().to_string(),
111        query: parts
112            .uri
113            .query()
114            .map(|v| {
115                url::form_urlencoded::parse(v.as_bytes())
116                    .into_owned()
117                    .collect()
118            })
119            .unwrap_or_else(std::collections::HashMap::new),
120    };
121
122    println!(
123        "{}",
124        serde_json::json!({"stamp": scru128::new(), "message": "request", "meta": request})
125    );
126
127    let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
128
129    // Wait for both:
130    // 1. Metadata - either from .response or default values when closure skips .response
131    // 2. Body pipeline to start (but not necessarily complete as it may stream)
132    let (meta, body_result): (
133        Response,
134        Result<(Option<String>, ResponseTransport), BoxError>,
135    ) = tokio::join!(
136        async {
137            meta_rx.await.unwrap_or(Response {
138                status: 200,
139                headers: std::collections::HashMap::new(),
140                body_type: ResponseBodyType::Normal,
141            })
142        },
143        async { bridged_body.await.map_err(|e| e.into()) }
144    );
145
146    let use_brotli = compression::accepts_brotli(&parts.headers);
147
148    match &meta.body_type {
149        ResponseBodyType::Normal => {
150            build_normal_response(&meta, Ok(body_result?), use_brotli).await
151        }
152        ResponseBodyType::Static {
153            root,
154            path,
155            fallback,
156        } => {
157            let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
158            *static_req.uri_mut() = format!("/{path}").parse().unwrap();
159            *static_req.method_mut() = parts.method.clone();
160            *static_req.headers_mut() = parts.headers.clone();
161
162            let res = if let Some(fallback) = fallback {
163                let fp = root.join(fallback);
164                ServeDir::new(root)
165                    .fallback(ServeFile::new(fp))
166                    .call(static_req)
167                    .await?
168            } else {
169                ServeDir::new(root).call(static_req).await?
170            };
171            let (parts, body) = res.into_parts();
172            let bytes = body.collect().await?.to_bytes();
173            let res = hyper::Response::from_parts(
174                parts,
175                Full::new(bytes).map_err(|e| match e {}).boxed(),
176            );
177            Ok(res)
178        }
179        ResponseBodyType::ReverseProxy {
180            target_url,
181            headers,
182            preserve_host,
183            strip_prefix,
184            request_body,
185            query,
186        } => {
187            let body = Full::new(Bytes::from(request_body.clone()));
188            let mut proxy_req = hyper::Request::new(body);
189
190            // Handle strip_prefix
191            let path = if let Some(prefix) = strip_prefix {
192                parts
193                    .uri
194                    .path()
195                    .strip_prefix(prefix)
196                    .unwrap_or(parts.uri.path())
197            } else {
198                parts.uri.path()
199            };
200
201            // Build target URI
202            let target_uri = {
203                let query_string = if let Some(custom_query) = query {
204                    // Use custom query - convert HashMap to query string
205                    url::form_urlencoded::Serializer::new(String::new())
206                        .extend_pairs(custom_query.iter())
207                        .finish()
208                } else if let Some(orig_query) = parts.uri.query() {
209                    // Use original query string
210                    orig_query.to_string()
211                } else {
212                    String::new()
213                };
214
215                if query_string.is_empty() {
216                    format!("{target_url}{path}")
217                } else {
218                    format!("{target_url}{path}?{query_string}")
219                }
220            };
221
222            *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
223            *proxy_req.method_mut() = parts.method.clone();
224
225            // Copy original headers
226            let mut header_map = parts.headers.clone();
227
228            // Update Content-Length to match the new body
229            if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
230                header_map.insert(
231                    hyper::header::CONTENT_LENGTH,
232                    hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
233                );
234            }
235
236            // Add custom headers
237            for (k, v) in headers {
238                let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
239
240                match v {
241                    crate::response::HeaderValue::Single(s) => {
242                        let header_value = hyper::header::HeaderValue::from_str(s)?;
243                        header_map.insert(header_name, header_value);
244                    }
245                    crate::response::HeaderValue::Multiple(values) => {
246                        for value in values {
247                            if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
248                                header_map.append(header_name.clone(), header_value);
249                            }
250                        }
251                    }
252                }
253            }
254
255            // Handle preserve_host
256            if !preserve_host {
257                if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
258                    if let Some(authority) = target_uri.authority() {
259                        header_map.insert(
260                            hyper::header::HOST,
261                            hyper::header::HeaderValue::from_str(authority.as_ref())?,
262                        );
263                    }
264                }
265            }
266
267            *proxy_req.headers_mut() = header_map;
268
269            // Create a simple HTTP client and forward the request
270            let client =
271                hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
272                    .build_http();
273
274            match client.request(proxy_req).await {
275                Ok(response) => {
276                    let (parts, body) = response.into_parts();
277                    // Stream the response body directly without buffering
278                    let res =
279                        hyper::Response::from_parts(parts, body.map_err(|e| e.into()).boxed());
280                    Ok(res)
281                }
282                Err(_e) => {
283                    let response = hyper::Response::builder().status(502).body(
284                        Full::new("Bad Gateway".into())
285                            .map_err(|never| match never {})
286                            .boxed(),
287                    )?;
288                    Ok(response)
289                }
290            }
291        }
292    }
293}
294
295async fn build_normal_response(
296    meta: &Response,
297    body_result: Result<(Option<String>, ResponseTransport), BoxError>,
298    use_brotli: bool,
299) -> HTTPResult {
300    let (inferred_content_type, body) = body_result?;
301    let mut builder = hyper::Response::builder().status(meta.status);
302    let mut header_map = hyper::header::HeaderMap::new();
303
304    let content_type = meta
305        .headers
306        .get("content-type")
307        .or(meta.headers.get("Content-Type"))
308        .and_then(|hv| match hv {
309            crate::response::HeaderValue::Single(s) => Some(s.clone()),
310            crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
311        })
312        .or(inferred_content_type)
313        .unwrap_or("text/html; charset=utf-8".to_string());
314
315    header_map.insert(
316        hyper::header::CONTENT_TYPE,
317        hyper::header::HeaderValue::from_str(&content_type)?,
318    );
319
320    // Add compression headers if using brotli
321    if use_brotli {
322        header_map.insert(
323            hyper::header::CONTENT_ENCODING,
324            hyper::header::HeaderValue::from_static("br"),
325        );
326        header_map.insert(
327            hyper::header::VARY,
328            hyper::header::HeaderValue::from_static("accept-encoding"),
329        );
330    }
331
332    for (k, v) in &meta.headers {
333        if k.to_lowercase() != "content-type" {
334            let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
335
336            match v {
337                crate::response::HeaderValue::Single(s) => {
338                    let header_value = hyper::header::HeaderValue::from_str(s)?;
339                    header_map.insert(header_name, header_value);
340                }
341                crate::response::HeaderValue::Multiple(values) => {
342                    for value in values {
343                        if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
344                            header_map.append(header_name.clone(), header_value);
345                        }
346                    }
347                }
348            }
349        }
350    }
351
352    *builder.headers_mut().unwrap() = header_map;
353
354    let body = match body {
355        ResponseTransport::Empty => Empty::<Bytes>::new()
356            .map_err(|never| match never {})
357            .boxed(),
358        ResponseTransport::Full(bytes) => {
359            if use_brotli {
360                let compressed = compression::compress_full(&bytes)?;
361                Full::new(Bytes::from(compressed))
362                    .map_err(|never| match never {})
363                    .boxed()
364            } else {
365                Full::new(bytes.into())
366                    .map_err(|never| match never {})
367                    .boxed()
368            }
369        }
370        ResponseTransport::Stream(rx) => {
371            if use_brotli {
372                compression::compress_stream(rx)
373            } else {
374                let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
375                StreamBody::new(stream).boxed()
376            }
377        }
378    };
379
380    Ok(builder.body(body)?)
381}