1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use arc_swap::ArcSwap;
5use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full, StreamBody};
6use hyper::body::{Bytes, Frame};
7use tokio_stream::wrappers::ReceiverStream;
8use tokio_stream::StreamExt;
9use tower::Service;
10use tower_http::services::{ServeDir, ServeFile};
11
12use crate::compression;
13use crate::request::Request;
14use crate::response::{Response, ResponseBodyType, ResponseTransport};
15use crate::worker::spawn_eval_thread;
16
/// Boxed, thread-safe error type shared by the request handlers in this file.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Result of producing a response: a `hyper::Response` with a type-erased
/// body on success, or a [`BoxError`] on failure.
type HTTPResult = Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>;
19
20pub async fn handle<B>(
21 engine: Arc<ArcSwap<crate::Engine>>,
22 addr: Option<SocketAddr>,
23 req: hyper::Request<B>,
24) -> Result<hyper::Response<BoxBody<Bytes, BoxError>>, BoxError>
25where
26 B: hyper::body::Body + Unpin + Send + 'static,
27 B::Data: Into<Bytes> + Clone + Send,
28 B::Error: Into<BoxError> + Send,
29{
30 let engine = engine.load_full();
32 match handle_inner(engine, addr, req).await {
33 Ok(response) => Ok(response),
34 Err(err) => {
35 eprintln!("Error handling request: {err}");
36 let response = hyper::Response::builder().status(500).body(
37 Full::new(format!("Script error: {err}").into())
38 .map_err(|never| match never {})
39 .boxed(),
40 )?;
41 Ok(response)
42 }
43 }
44}
45
46async fn handle_inner<B>(
47 engine: Arc<crate::Engine>,
48 addr: Option<SocketAddr>,
49 req: hyper::Request<B>,
50) -> HTTPResult
51where
52 B: hyper::body::Body + Unpin + Send + 'static,
53 B::Data: Into<Bytes> + Clone + Send,
54 B::Error: Into<BoxError> + Send,
55{
56 let (parts, mut body) = req.into_parts();
57
58 let (body_tx, mut body_rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, BoxError>>(32);
60
61 tokio::task::spawn(async move {
63 while let Some(frame) = body.frame().await {
64 match frame {
65 Ok(frame) => {
66 if let Some(data) = frame.data_ref() {
67 let bytes: Bytes = (*data).clone().into();
68 if body_tx.send(Ok(bytes.to_vec())).await.is_err() {
69 break;
70 }
71 }
72 }
73 Err(err) => {
74 let _ = body_tx.send(Err(err.into())).await;
75 break;
76 }
77 }
78 }
79 });
80
81 let stream = nu_protocol::ByteStream::from_fn(
83 nu_protocol::Span::unknown(),
84 engine.state.signals().clone(),
85 nu_protocol::ByteStreamType::Unknown,
86 move |buffer: &mut Vec<u8>| match body_rx.blocking_recv() {
87 Some(Ok(bytes)) => {
88 buffer.extend_from_slice(&bytes);
89 Ok(true)
90 }
91 Some(Err(err)) => Err(nu_protocol::ShellError::GenericError {
92 error: "Body read error".into(),
93 msg: err.to_string(),
94 span: None,
95 help: None,
96 inner: vec![],
97 }),
98 None => Ok(false),
99 },
100 );
101
102 let request = Request {
103 proto: format!("{:?}", parts.version),
104 method: parts.method.clone(),
105 authority: parts.uri.authority().map(|a| a.to_string()),
106 remote_ip: addr.as_ref().map(|a| a.ip()),
107 remote_port: addr.as_ref().map(|a| a.port()),
108 headers: parts.headers.clone(),
109 uri: parts.uri.clone(),
110 path: parts.uri.path().to_string(),
111 query: parts
112 .uri
113 .query()
114 .map(|v| {
115 url::form_urlencoded::parse(v.as_bytes())
116 .into_owned()
117 .collect()
118 })
119 .unwrap_or_else(std::collections::HashMap::new),
120 };
121
122 println!(
123 "{}",
124 serde_json::json!({"stamp": scru128::new(), "message": "request", "meta": request})
125 );
126
127 let (meta_rx, bridged_body) = spawn_eval_thread(engine, request, stream);
128
129 let (meta, body_result): (
133 Response,
134 Result<(Option<String>, ResponseTransport), BoxError>,
135 ) = tokio::join!(
136 async {
137 meta_rx.await.unwrap_or(Response {
138 status: 200,
139 headers: std::collections::HashMap::new(),
140 body_type: ResponseBodyType::Normal,
141 })
142 },
143 async { bridged_body.await.map_err(|e| e.into()) }
144 );
145
146 let use_brotli = compression::accepts_brotli(&parts.headers);
147
148 match &meta.body_type {
149 ResponseBodyType::Normal => {
150 build_normal_response(&meta, Ok(body_result?), use_brotli).await
151 }
152 ResponseBodyType::Static {
153 root,
154 path,
155 fallback,
156 } => {
157 let mut static_req = hyper::Request::new(Empty::<Bytes>::new());
158 *static_req.uri_mut() = format!("/{path}").parse().unwrap();
159 *static_req.method_mut() = parts.method.clone();
160 *static_req.headers_mut() = parts.headers.clone();
161
162 let res = if let Some(fallback) = fallback {
163 let fp = root.join(fallback);
164 ServeDir::new(root)
165 .fallback(ServeFile::new(fp))
166 .call(static_req)
167 .await?
168 } else {
169 ServeDir::new(root).call(static_req).await?
170 };
171 let (parts, body) = res.into_parts();
172 let bytes = body.collect().await?.to_bytes();
173 let res = hyper::Response::from_parts(
174 parts,
175 Full::new(bytes).map_err(|e| match e {}).boxed(),
176 );
177 Ok(res)
178 }
179 ResponseBodyType::ReverseProxy {
180 target_url,
181 headers,
182 preserve_host,
183 strip_prefix,
184 request_body,
185 query,
186 } => {
187 let body = Full::new(Bytes::from(request_body.clone()));
188 let mut proxy_req = hyper::Request::new(body);
189
190 let path = if let Some(prefix) = strip_prefix {
192 parts
193 .uri
194 .path()
195 .strip_prefix(prefix)
196 .unwrap_or(parts.uri.path())
197 } else {
198 parts.uri.path()
199 };
200
201 let target_uri = {
203 let query_string = if let Some(custom_query) = query {
204 url::form_urlencoded::Serializer::new(String::new())
206 .extend_pairs(custom_query.iter())
207 .finish()
208 } else if let Some(orig_query) = parts.uri.query() {
209 orig_query.to_string()
211 } else {
212 String::new()
213 };
214
215 if query_string.is_empty() {
216 format!("{target_url}{path}")
217 } else {
218 format!("{target_url}{path}?{query_string}")
219 }
220 };
221
222 *proxy_req.uri_mut() = target_uri.parse().map_err(|e| Box::new(e) as BoxError)?;
223 *proxy_req.method_mut() = parts.method.clone();
224
225 let mut header_map = parts.headers.clone();
227
228 if !request_body.is_empty() || header_map.contains_key(hyper::header::CONTENT_LENGTH) {
230 header_map.insert(
231 hyper::header::CONTENT_LENGTH,
232 hyper::header::HeaderValue::from_str(&request_body.len().to_string())?,
233 );
234 }
235
236 for (k, v) in headers {
238 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
239
240 match v {
241 crate::response::HeaderValue::Single(s) => {
242 let header_value = hyper::header::HeaderValue::from_str(s)?;
243 header_map.insert(header_name, header_value);
244 }
245 crate::response::HeaderValue::Multiple(values) => {
246 for value in values {
247 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
248 header_map.append(header_name.clone(), header_value);
249 }
250 }
251 }
252 }
253 }
254
255 if !preserve_host {
257 if let Ok(target_uri) = target_url.parse::<hyper::Uri>() {
258 if let Some(authority) = target_uri.authority() {
259 header_map.insert(
260 hyper::header::HOST,
261 hyper::header::HeaderValue::from_str(authority.as_ref())?,
262 );
263 }
264 }
265 }
266
267 *proxy_req.headers_mut() = header_map;
268
269 let client =
271 hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
272 .build_http();
273
274 match client.request(proxy_req).await {
275 Ok(response) => {
276 let (parts, body) = response.into_parts();
277 let res =
279 hyper::Response::from_parts(parts, body.map_err(|e| e.into()).boxed());
280 Ok(res)
281 }
282 Err(_e) => {
283 let response = hyper::Response::builder().status(502).body(
284 Full::new("Bad Gateway".into())
285 .map_err(|never| match never {})
286 .boxed(),
287 )?;
288 Ok(response)
289 }
290 }
291 }
292 }
293}
294
295async fn build_normal_response(
296 meta: &Response,
297 body_result: Result<(Option<String>, ResponseTransport), BoxError>,
298 use_brotli: bool,
299) -> HTTPResult {
300 let (inferred_content_type, body) = body_result?;
301 let mut builder = hyper::Response::builder().status(meta.status);
302 let mut header_map = hyper::header::HeaderMap::new();
303
304 let content_type = meta
305 .headers
306 .get("content-type")
307 .or(meta.headers.get("Content-Type"))
308 .and_then(|hv| match hv {
309 crate::response::HeaderValue::Single(s) => Some(s.clone()),
310 crate::response::HeaderValue::Multiple(v) => v.first().cloned(),
311 })
312 .or(inferred_content_type)
313 .unwrap_or("text/html; charset=utf-8".to_string());
314
315 header_map.insert(
316 hyper::header::CONTENT_TYPE,
317 hyper::header::HeaderValue::from_str(&content_type)?,
318 );
319
320 if use_brotli {
322 header_map.insert(
323 hyper::header::CONTENT_ENCODING,
324 hyper::header::HeaderValue::from_static("br"),
325 );
326 header_map.insert(
327 hyper::header::VARY,
328 hyper::header::HeaderValue::from_static("accept-encoding"),
329 );
330 }
331
332 for (k, v) in &meta.headers {
333 if k.to_lowercase() != "content-type" {
334 let header_name = hyper::header::HeaderName::from_bytes(k.as_bytes())?;
335
336 match v {
337 crate::response::HeaderValue::Single(s) => {
338 let header_value = hyper::header::HeaderValue::from_str(s)?;
339 header_map.insert(header_name, header_value);
340 }
341 crate::response::HeaderValue::Multiple(values) => {
342 for value in values {
343 if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
344 header_map.append(header_name.clone(), header_value);
345 }
346 }
347 }
348 }
349 }
350 }
351
352 *builder.headers_mut().unwrap() = header_map;
353
354 let body = match body {
355 ResponseTransport::Empty => Empty::<Bytes>::new()
356 .map_err(|never| match never {})
357 .boxed(),
358 ResponseTransport::Full(bytes) => {
359 if use_brotli {
360 let compressed = compression::compress_full(&bytes)?;
361 Full::new(Bytes::from(compressed))
362 .map_err(|never| match never {})
363 .boxed()
364 } else {
365 Full::new(bytes.into())
366 .map_err(|never| match never {})
367 .boxed()
368 }
369 }
370 ResponseTransport::Stream(rx) => {
371 if use_brotli {
372 compression::compress_stream(rx)
373 } else {
374 let stream = ReceiverStream::new(rx).map(|data| Ok(Frame::data(Bytes::from(data))));
375 StreamBody::new(stream).boxed()
376 }
377 }
378 };
379
380 Ok(builder.body(body)?)
381}