Skip to main content

graphix_package_http/
lib.rs

1#![doc(
2    html_logo_url = "https://graphix-lang.github.io/graphix/graphix-icon.svg",
3    html_favicon_url = "https://graphix-lang.github.io/graphix/graphix-icon.svg"
4)]
5use anyhow::{bail, Result};
6use arcstr::{literal, ArcStr};
7use bytes::Bytes;
8use compact_str::format_compact;
9use futures::{channel::mpsc, SinkExt};
10use graphix_compiler::{
11    errf,
12    expr::ExprId,
13    node::genn,
14    typ::{FnType, Type},
15    Apply, BindId, BuiltIn, CustomBuiltinType, Event, ExecCtx, LambdaId, Node, Rt, Scope,
16    UserEvent, CBATCH_POOL,
17};
18use graphix_package_core::{
19    CachedArgs, CachedArgsAsync, CachedVals, EvalCached, EvalCachedAsync,
20};
21use graphix_rt::GXRt;
22use netidx_value::{
23    abstract_type::AbstractWrapper, Abstract, FromValue, PBytes, ValArray, Value,
24};
25use std::{
26    any::Any,
27    cmp::Ordering,
28    collections::VecDeque,
29    fmt,
30    hash::{Hash, Hasher},
31    pin::Pin,
32    sync::{Arc, LazyLock},
33    task::{Context, Poll},
34    time::Duration,
35};
36use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
37
38// ── Abstract ClientValue ─────────────────────────────────────────
39
40#[derive(Debug, Clone)]
41struct ClientValue {
42    client: Arc<reqwest::Client>,
43}
44
45impl PartialEq for ClientValue {
46    fn eq(&self, other: &Self) -> bool {
47        Arc::ptr_eq(&self.client, &other.client)
48    }
49}
50
51impl Eq for ClientValue {}
52
53impl PartialOrd for ClientValue {
54    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
55        Some(self.cmp(other))
56    }
57}
58
59impl Ord for ClientValue {
60    fn cmp(&self, other: &Self) -> Ordering {
61        Arc::as_ptr(&self.client).cmp(&Arc::as_ptr(&other.client))
62    }
63}
64
65impl Hash for ClientValue {
66    fn hash<H: Hasher>(&self, state: &mut H) {
67        Arc::as_ptr(&self.client).hash(state)
68    }
69}
70
71graphix_package_core::impl_no_pack!(ClientValue);
72
73static CLIENT_WRAPPER: LazyLock<AbstractWrapper<ClientValue>> = LazyLock::new(|| {
74    let id = uuid::Uuid::from_bytes([
75        0xc7, 0xd8, 0xe9, 0xfa, 0x0b, 0x1c, 0x4d, 0x2e, 0x3f, 0x40, 0x51, 0x62, 0x73,
76        0x84, 0x95, 0xa6,
77    ]);
78    Abstract::register::<ClientValue>(id).expect("failed to register ClientValue")
79});
80
81fn get_client(cached: &CachedVals, idx: usize) -> Option<Arc<reqwest::Client>> {
82    match cached.0.get(idx)?.as_ref()? {
83        Value::Abstract(a) => {
84            let cv = a.downcast_ref::<ClientValue>()?;
85            Some(cv.client.clone())
86        }
87        _ => None,
88    }
89}
90
91// ── Abstract ServerValue ─────────────────────────────────────────
92
93#[derive(Debug)]
94struct ServerHandle {
95    abort: tokio::task::AbortHandle,
96    addr: std::net::SocketAddr,
97}
98
99impl Drop for ServerHandle {
100    fn drop(&mut self) {
101        self.abort.abort();
102    }
103}
104
105#[derive(Debug, Clone)]
106struct ServerValue {
107    handle: Arc<ServerHandle>,
108}
109
110impl PartialEq for ServerValue {
111    fn eq(&self, other: &Self) -> bool {
112        Arc::ptr_eq(&self.handle, &other.handle)
113    }
114}
115
116impl Eq for ServerValue {}
117
118impl PartialOrd for ServerValue {
119    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
120        Some(self.cmp(other))
121    }
122}
123
124impl Ord for ServerValue {
125    fn cmp(&self, other: &Self) -> Ordering {
126        Arc::as_ptr(&self.handle).cmp(&Arc::as_ptr(&other.handle))
127    }
128}
129
130impl Hash for ServerValue {
131    fn hash<H: Hasher>(&self, state: &mut H) {
132        Arc::as_ptr(&self.handle).hash(state)
133    }
134}
135
136graphix_package_core::impl_no_pack!(ServerValue);
137
138static SERVER_WRAPPER: LazyLock<AbstractWrapper<ServerValue>> = LazyLock::new(|| {
139    let id = uuid::Uuid::from_bytes([
140        0xd7, 0xe8, 0xf9, 0x0a, 0x1b, 0x2c, 0x4d, 0x3e, 0x4f, 0x50, 0x61, 0x72, 0x83,
141        0x94, 0xa5, 0xb6,
142    ]);
143    Abstract::register::<ServerValue>(id).expect("failed to register ServerValue")
144});
145
146// ── Shared helpers ───────────────────────────────────────────────
147
148fn value_to_header_map(v: &Value) -> reqwest::header::HeaderMap {
149    let mut map = reqwest::header::HeaderMap::new();
150    if let Value::Array(arr) = v {
151        for pair in arr.iter() {
152            if let Value::Array(p) = pair {
153                if p.len() == 2 {
154                    if let (Value::String(k), Value::String(v)) = (&p[0], &p[1]) {
155                        if let (Ok(name), Ok(val)) = (
156                            reqwest::header::HeaderName::from_bytes(k.as_bytes()),
157                            reqwest::header::HeaderValue::from_str(v),
158                        ) {
159                            map.append(name, val);
160                        }
161                    }
162                }
163            }
164        }
165    }
166    map
167}
168
169fn headers_to_value<'a>(
170    iter: impl Iterator<
171        Item = (&'a hyper::header::HeaderName, &'a hyper::header::HeaderValue),
172    >,
173) -> Value {
174    let v: Vec<Value> = iter
175        .map(|(k, v)| {
176            Value::Array(ValArray::from([
177                Value::String(ArcStr::from(k.as_str())),
178                Value::String(ArcStr::from(v.to_str().unwrap_or(""))),
179            ]))
180        })
181        .collect();
182    Value::Array(ValArray::from(v))
183}
184
185fn build_response(body: ArcStr, headers: Value, status: u16, url: ArcStr) -> Value {
186    let r: [(ArcStr, Value); 4] = [
187        (literal!("body"), Value::String(body)),
188        (literal!("headers"), headers),
189        (literal!("status"), Value::U16(status)),
190        (literal!("url"), Value::String(url)),
191    ];
192    r.into()
193}
194
195fn build_bin_response(body: Bytes, headers: Value, status: u16, url: ArcStr) -> Value {
196    let r: [(ArcStr, Value); 4] = [
197        (literal!("body"), Value::Bytes(PBytes::new(body))),
198        (literal!("headers"), headers),
199        (literal!("status"), Value::U16(status)),
200        (literal!("url"), Value::String(url)),
201    ];
202    r.into()
203}
204
205fn parse_method(s: &str) -> std::result::Result<reqwest::Method, String> {
206    match s {
207        "GET" => Ok(reqwest::Method::GET),
208        "POST" => Ok(reqwest::Method::POST),
209        "PUT" => Ok(reqwest::Method::PUT),
210        "DELETE" => Ok(reqwest::Method::DELETE),
211        "PATCH" => Ok(reqwest::Method::PATCH),
212        "HEAD" => Ok(reqwest::Method::HEAD),
213        "OPTIONS" => Ok(reqwest::Method::OPTIONS),
214        other => Err(format!("unknown HTTP method: {other}")),
215    }
216}
217
218static DEFAULT_CLIENT: LazyLock<Arc<reqwest::Client>> = LazyLock::new(|| {
219    Arc::new(
220        reqwest::Client::builder().build().expect("failed to create default HTTP client"),
221    )
222});
223
224// ── HttpClient ───────────────────────────────────────────────────
225
226#[derive(Debug, Default)]
227pub(crate) struct HttpClientEv;
228
229impl<R: Rt, E: UserEvent> EvalCached<R, E> for HttpClientEv {
230    const NAME: &str = "http_client";
231    const NEEDS_CALLSITE: bool = false;
232
233    fn eval(&mut self, _ctx: &mut ExecCtx<R, E>, cached: &CachedVals) -> Option<Value> {
234        let timeout = cached.get::<Option<Duration>>(0)?;
235        let default_headers = cached.0.get(1)?.as_ref()?.clone();
236        let redirect_limit = cached.get::<u32>(2)?;
237        let ca_cert = cached.get::<Option<Bytes>>(3)?;
238        let _ = cached.0.get(4)?.as_ref()?;
239        let mut builder = reqwest::Client::builder();
240        if let Some(timeout) = timeout {
241            builder = builder.timeout(timeout);
242        }
243        builder =
244            builder.redirect(reqwest::redirect::Policy::limited(redirect_limit as usize));
245        let headers = value_to_header_map(&default_headers);
246        if !headers.is_empty() {
247            builder = builder.default_headers(headers);
248        }
249        if let Some(ca_cert) = &ca_cert {
250            let cert = match reqwest::Certificate::from_pem(ca_cert) {
251                Ok(c) => c,
252                Err(e) => return Some(errf!("HTTPError", "invalid ca_cert PEM: {e}")),
253            };
254            builder = builder.add_root_certificate(cert);
255        }
256        Some(match builder.build() {
257            Ok(client) => CLIENT_WRAPPER.wrap(ClientValue { client: Arc::new(client) }),
258            Err(e) => errf!("HTTPError", "failed to build client: {e}"),
259        })
260    }
261}
262
263pub(crate) type HttpClient = CachedArgs<HttpClientEv>;
264
265// ── HttpDefaultClient ────────────────────────────────────────────
266
267#[derive(Debug, Default)]
268pub(crate) struct HttpDefaultClientEv;
269
270impl<R: Rt, E: UserEvent> EvalCached<R, E> for HttpDefaultClientEv {
271    const NAME: &str = "http_default_client";
272    const NEEDS_CALLSITE: bool = false;
273
274    fn eval(&mut self, _ctx: &mut ExecCtx<R, E>, cached: &CachedVals) -> Option<Value> {
275        cached.0.get(0)?.as_ref()?;
276        Some(CLIENT_WRAPPER.wrap(ClientValue { client: DEFAULT_CLIENT.clone() }))
277    }
278}
279
280pub(crate) type HttpDefaultClient = CachedArgs<HttpDefaultClientEv>;
281
282// ── HttpServerAddr ──────────────────────────────────────────────
283
284#[derive(Debug, Default)]
285pub(crate) struct HttpServerAddrEv;
286
287impl<R: Rt, E: UserEvent> EvalCached<R, E> for HttpServerAddrEv {
288    const NAME: &str = "http_server_addr";
289    const NEEDS_CALLSITE: bool = false;
290
291    fn eval(&mut self, _ctx: &mut ExecCtx<R, E>, cached: &CachedVals) -> Option<Value> {
292        let v = cached.0.get(0)?.as_ref()?;
293        match v {
294            Value::Abstract(a) => {
295                let sv = a.downcast_ref::<ServerValue>()?;
296                Some(Value::String(ArcStr::from(sv.handle.addr.to_string().as_str())))
297            }
298            _ => None,
299        }
300    }
301}
302
303pub(crate) type HttpServerAddr = CachedArgs<HttpServerAddrEv>;
304
305// ── HttpRequest / HttpRequestBin ─────────────────────────────────
306
/// Arguments shared by `http_request` and `http_request_bin`, gathered from
/// the cached call arguments by `prepare_request_args`.
///
/// `B` is the request-body type: `ArcStr` for `http_request`, `Bytes` for
/// `http_request_bin`. The argument position of each field is in brackets.
#[derive(Debug)]
pub(crate) struct RequestArgs<B> {
    /// [0] HTTP method name; validated later by `parse_method`.
    method: ArcStr,
    /// [1] Request headers as an array of `[name, value]` string pairs.
    headers: Value,
    /// [2] Optional request body.
    body: Option<B>,
    /// [3] Optional per-request timeout.
    timeout: Option<Duration>,
    /// [4] Client the request is sent with.
    client: Arc<reqwest::Client>,
    /// [5] Target URL.
    url: ArcStr,
}
316
317fn prepare_request_args<B: FromValue>(cached: &CachedVals) -> Option<RequestArgs<B>> {
318    let method = cached.get::<ArcStr>(0)?;
319    let headers = cached.0.get(1)?.as_ref()?.clone();
320    let body = cached.get::<Option<B>>(2)?;
321    let timeout = cached.get::<Option<Duration>>(3)?;
322    let client = get_client(cached, 4)?;
323    let url = cached.get::<ArcStr>(5)?;
324    Some(RequestArgs { method, headers, body, timeout, client, url })
325}
326
327async fn send_request(
328    method: &str,
329    client: &reqwest::Client,
330    url: &str,
331    headers: &Value,
332    body: Option<reqwest::Body>,
333    timeout: Option<Duration>,
334) -> std::result::Result<reqwest::Response, Value> {
335    let method = parse_method(method).map_err(|e| errf!("HTTPError", "{e}"))?;
336    let mut req = client.request(method, url);
337    let hdrs = value_to_header_map(headers);
338    if !hdrs.is_empty() {
339        req = req.headers(hdrs);
340    }
341    if let Some(body) = body {
342        req = req.body(body);
343    }
344    if let Some(timeout) = timeout {
345        req = req.timeout(timeout);
346    }
347    req.send().await.map_err(|e| errf!("HTTPError", "request failed: {e}"))
348}
349
350#[derive(Debug, Default)]
351pub(crate) struct HttpRequestEv;
352
353impl EvalCachedAsync for HttpRequestEv {
354    const NAME: &str = "http_request";
355    const NEEDS_CALLSITE: bool = false;
356    type Args = RequestArgs<ArcStr>;
357
358    fn prepare_args(&mut self, cached: &CachedVals) -> Option<Self::Args> {
359        prepare_request_args(cached)
360    }
361
362    fn eval(args: Self::Args) -> impl Future<Output = Value> + Send {
363        async move {
364            let resp = match send_request(
365                &args.method,
366                &args.client,
367                &args.url,
368                &args.headers,
369                args.body.map(|s| reqwest::Body::from(s.to_string())),
370                args.timeout,
371            )
372            .await
373            {
374                Ok(r) => r,
375                Err(e) => return e,
376            };
377            let status = resp.status().as_u16();
378            let url = ArcStr::from(resp.url().as_str());
379            let hdrs = headers_to_value(resp.headers().iter());
380            match resp.text().await {
381                Ok(body) => {
382                    build_response(ArcStr::from(body.as_str()), hdrs, status, url)
383                }
384                Err(e) => errf!("HTTPError", "failed to read body: {e}"),
385            }
386        }
387    }
388}
389
390pub(crate) type HttpRequest = CachedArgsAsync<HttpRequestEv>;
391
392#[derive(Debug, Default)]
393pub(crate) struct HttpRequestBinEv;
394
395impl EvalCachedAsync for HttpRequestBinEv {
396    const NAME: &str = "http_request_bin";
397    const NEEDS_CALLSITE: bool = false;
398    type Args = RequestArgs<Bytes>;
399
400    fn prepare_args(&mut self, cached: &CachedVals) -> Option<Self::Args> {
401        prepare_request_args(cached)
402    }
403
404    fn eval(args: Self::Args) -> impl Future<Output = Value> + Send {
405        async move {
406            let resp = match send_request(
407                &args.method,
408                &args.client,
409                &args.url,
410                &args.headers,
411                args.body.map(reqwest::Body::from),
412                args.timeout,
413            )
414            .await
415            {
416                Ok(r) => r,
417                Err(e) => return e,
418            };
419            let status = resp.status().as_u16();
420            let url = ArcStr::from(resp.url().as_str());
421            let hdrs = headers_to_value(resp.headers().iter());
422            match resp.bytes().await {
423                Ok(body) => build_bin_response(body, hdrs, status, url),
424                Err(e) => errf!("HTTPError", "failed to read body: {e}"),
425            }
426        }
427    }
428}
429
430pub(crate) type HttpRequestBin = CachedArgsAsync<HttpRequestBinEv>;
431
432// ── HttpServe (server) ───────────────────────────────────────────
433
434struct HttpReqEvent {
435    request: Value,
436    reply: Option<tokio::sync::oneshot::Sender<Value>>,
437}
438
439impl fmt::Debug for HttpReqEvent {
440    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441        f.debug_struct("HttpReqEvent")
442            .field("request", &self.request)
443            .field("reply", &self.reply.is_some())
444            .finish()
445    }
446}
447
448impl CustomBuiltinType for HttpReqEvent {}
449
450fn build_request_value(
451    body: Option<ArcStr>,
452    headers: Value,
453    method: ArcStr,
454    path: ArcStr,
455    query: Option<ArcStr>,
456) -> Value {
457    let r: [(ArcStr, Value); 5] = [
458        (literal!("body"), body.map(Value::String).unwrap_or(Value::Null)),
459        (literal!("headers"), headers),
460        (literal!("method"), Value::String(method)),
461        (literal!("path"), Value::String(path)),
462        (literal!("query"), query.map(Value::String).unwrap_or(Value::Null)),
463    ];
464    r.into()
465}
466
467fn struct_field(v: &Value, idx: usize) -> Option<&Value> {
468    match v {
469        Value::Array(arr) => match arr.get(idx)? {
470            Value::Array(pair) if pair.len() == 2 => Some(&pair[1]),
471            _ => None,
472        },
473        _ => None,
474    }
475}
476
477fn build_hyper_response(
478    v: &Value,
479) -> std::result::Result<
480    hyper::Response<http_body_util::Full<Bytes>>,
481    std::convert::Infallible,
482> {
483    if let Value::Error(e) = v {
484        return Ok(hyper::Response::builder()
485            .status(500)
486            .body(http_body_util::Full::new(Bytes::from(format!("{e}"))))
487            .unwrap());
488    }
489    // Response fields (alphabetical): body(0), headers(1), status(2), url(3)
490    let body = match struct_field(v, 0) {
491        Some(Value::String(s)) => Bytes::from(s.to_string()),
492        _ => Bytes::new(),
493    };
494    let status = match struct_field(v, 2) {
495        Some(Value::U16(s)) => *s,
496        _ => 200,
497    };
498    let mut response = hyper::Response::builder().status(status);
499    if let Some(Value::Array(hdrs)) = struct_field(v, 1) {
500        for h in hdrs.iter() {
501            if let Value::Array(pair) = h {
502                if pair.len() == 2 {
503                    if let (Value::String(k), Value::String(v)) = (&pair[0], &pair[1]) {
504                        response = response.header(&**k, &**v);
505                    }
506                }
507            }
508        }
509    }
510    Ok(response.body(http_body_util::Full::new(body)).unwrap())
511}
512
513async fn handle_http_request(
514    req: hyper::Request<hyper::body::Incoming>,
515    mut tx: mpsc::Sender<
516        poolshark::global::GPooled<Vec<(BindId, Box<dyn CustomBuiltinType>)>>,
517    >,
518    id: BindId,
519) -> std::result::Result<
520    hyper::Response<http_body_util::Full<Bytes>>,
521    std::convert::Infallible,
522> {
523    use http_body_util::BodyExt;
524    let (parts, body) = req.into_parts();
525    let body_bytes = match body.collect().await {
526        Ok(b) => b.to_bytes(),
527        Err(_) => Bytes::new(),
528    };
529    let method = ArcStr::from(parts.method.as_str());
530    let path = ArcStr::from(parts.uri.path());
531    let query = parts.uri.query().map(ArcStr::from);
532    let hdrs = headers_to_value(parts.headers.iter());
533    let body_str = if body_bytes.is_empty() {
534        None
535    } else {
536        match std::str::from_utf8(&body_bytes) {
537            Ok(s) => Some(ArcStr::from(s)),
538            Err(_) => None,
539        }
540    };
541    let request_value = build_request_value(body_str, hdrs, method, path, query);
542    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
543    let mut batch = CBATCH_POOL.take();
544    batch.push((
545        id,
546        Box::new(HttpReqEvent { request: request_value, reply: Some(reply_tx) })
547            as Box<dyn CustomBuiltinType>,
548    ));
549    if tx.send(batch).await.is_err() {
550        return Ok(hyper::Response::builder()
551            .status(503)
552            .body(http_body_util::Full::new(Bytes::from("Service Unavailable")))
553            .unwrap());
554    }
555    match reply_rx.await {
556        Ok(resp_value) => build_hyper_response(&resp_value),
557        Err(_) => Ok(hyper::Response::builder()
558            .status(500)
559            .body(http_body_util::Full::new(Bytes::from("Internal Server Error")))
560            .unwrap()),
561    }
562}
563
564fn build_tls_acceptor(
565    cert_pem: &[u8],
566    key_pem: &[u8],
567) -> std::result::Result<tokio_rustls::TlsAcceptor, Value> {
568    let certs: Vec<_> = rustls_pemfile::certs(&mut &*cert_pem)
569        .collect::<std::result::Result<_, _>>()
570        .map_err(|e| errf!("HTTPError", "invalid cert PEM: {e}"))?;
571    let key = rustls_pemfile::private_key(&mut &*key_pem)
572        .map_err(|e| errf!("HTTPError", "invalid key PEM: {e}"))?
573        .ok_or_else(|| errf!("HTTPError", "no private key found in key PEM"))?;
574    let config = rustls::ServerConfig::builder()
575        .with_no_client_auth()
576        .with_single_cert(certs, key)
577        .map_err(|e| errf!("HTTPError", "TLS config error: {e}"))?;
578    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(config)))
579}
580
581enum MaybeTls {
582    Plain(tokio::net::TcpStream),
583    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
584}
585
586impl AsyncRead for MaybeTls {
587    fn poll_read(
588        self: Pin<&mut Self>,
589        cx: &mut Context<'_>,
590        buf: &mut ReadBuf<'_>,
591    ) -> Poll<std::io::Result<()>> {
592        match self.get_mut() {
593            MaybeTls::Plain(s) => Pin::new(s).poll_read(cx, buf),
594            MaybeTls::Tls(s) => Pin::new(s).poll_read(cx, buf),
595        }
596    }
597}
598
599impl AsyncWrite for MaybeTls {
600    fn poll_write(
601        self: Pin<&mut Self>,
602        cx: &mut Context<'_>,
603        buf: &[u8],
604    ) -> Poll<std::io::Result<usize>> {
605        match self.get_mut() {
606            MaybeTls::Plain(s) => Pin::new(s).poll_write(cx, buf),
607            MaybeTls::Tls(s) => Pin::new(s).poll_write(cx, buf),
608        }
609    }
610
611    fn poll_flush(
612        self: Pin<&mut Self>,
613        cx: &mut Context<'_>,
614    ) -> Poll<std::io::Result<()>> {
615        match self.get_mut() {
616            MaybeTls::Plain(s) => Pin::new(s).poll_flush(cx),
617            MaybeTls::Tls(s) => Pin::new(s).poll_flush(cx),
618        }
619    }
620
621    fn poll_shutdown(
622        self: Pin<&mut Self>,
623        cx: &mut Context<'_>,
624    ) -> Poll<std::io::Result<()>> {
625        match self.get_mut() {
626            MaybeTls::Plain(s) => Pin::new(s).poll_shutdown(cx),
627            MaybeTls::Tls(s) => Pin::new(s).poll_shutdown(cx),
628        }
629    }
630}
631
632async fn serve_loop(
633    listener: tokio::net::TcpListener,
634    tls: Option<tokio_rustls::TlsAcceptor>,
635    tx: mpsc::Sender<
636        poolshark::global::GPooled<Vec<(BindId, Box<dyn CustomBuiltinType>)>>,
637    >,
638    id: BindId,
639    max_connections: Arc<tokio::sync::Semaphore>,
640) {
641    loop {
642        let permit = match max_connections.clone().acquire_owned().await {
643            Ok(p) => p,
644            Err(_) => return, // semaphore closed
645        };
646        let (stream, _) = match listener.accept().await {
647            Ok(conn) => conn,
648            Err(e) => {
649                log::error!("HTTP accept error: {e}");
650                continue;
651            }
652        };
653        let io = match &tls {
654            None => MaybeTls::Plain(stream),
655            Some(acceptor) => match acceptor.accept(stream).await {
656                Ok(tls_stream) => MaybeTls::Tls(tls_stream),
657                Err(e) => {
658                    log::error!("TLS handshake error: {e}");
659                    continue;
660                }
661            },
662        };
663        let io = hyper_util::rt::TokioIo::new(io);
664        let tx = tx.clone();
665        tokio::spawn(async move {
666            let _permit = permit;
667            let service = hyper::service::service_fn(|req| {
668                handle_http_request(req, tx.clone(), id)
669            });
670            if let Err(e) = hyper::server::conn::http1::Builder::new()
671                .serve_connection(io, service)
672                .await
673            {
674                log::error!("HTTP connection error: {e}");
675            }
676        });
677    }
678}
679
/// State for one `http_serve` call site.
///
/// Arguments by position:
/// - 0: listen address (string);
/// - 1: optional PEM certificate (bytes);
/// - 2: optional PEM private key (bytes);
/// - 3: optional max connections (i64);
/// - 4: the handler function.
///
/// Incoming requests are queued and given to the handler one at a time. Each
/// value the handler produces answers the request at the front of the queue.
#[derive(Debug)]
pub(crate) struct HttpServe<R: Rt, E: UserEvent> {
    /// Latest value of each of the five arguments.
    args: CachedVals,
    /// Tag that the server's connection tasks attach to request events;
    /// `update` collects events under this id. Replaced on `sleep`.
    id: BindId,
    /// Expression id of the call site, used with `ref_var`/`unref_var`.
    top_id: ExprId,
    /// Generated call of the handler function (`pid`) on the current request (`x`).
    handler: Node<R, E>,
    /// Binding that holds the handler function value (argument 4).
    pid: BindId,
    /// Binding that the current request struct is written to.
    x: BindId,
    /// Pending requests, each with the channel used to send back its reply.
    queue: VecDeque<(Value, Option<tokio::sync::oneshot::Sender<Value>>)>,
    /// `true` when no request is currently handed to the handler.
    ready: bool,
    /// Abort handle of the running accept loop, if a server is running.
    abort: Option<tokio::task::AbortHandle>,
}
692
impl<R: Rt, E: UserEvent> BuiltIn<R, E> for HttpServe<R, E> {
    const NAME: &str = "http_serve";
    const NEEDS_CALLSITE: bool = false;

    /// Create the node for a five-argument `http_serve` call.
    ///
    /// Allocates the event tag `id` (registered with the runtime via
    /// `ref_var`), a binding `pid` for the handler function, and a binding `x`
    /// for the current request. It then generates the handler call node
    /// `pid(x)`. No server is started here; that happens in `update` once the
    /// address argument has a value.
    ///
    /// # Errors
    /// Fails if the call does not have exactly five arguments, or if argument
    /// 4's type is not a function type.
    fn init<'a, 'b, 'c, 'd>(
        ctx: &'a mut ExecCtx<R, E>,
        typ: &'a graphix_compiler::typ::FnType,
        resolved: Option<&'d FnType>,
        scope: &'b Scope,
        from: &'c [Node<R, E>],
        top_id: ExprId,
    ) -> Result<Box<dyn Apply<R, E>>> {
        match from {
            [_, _, _, _, _] => {
                // Prefer the resolved signature when one was supplied.
                let typ = resolved.unwrap_or(typ);
                // Fresh, uniquely named scope for the generated handler call.
                let scope =
                    scope.append(&format_compact!("fn{}", LambdaId::new().inner()));
                let id = BindId::new();
                ctx.rt.ref_var(id, top_id);
                let pid = BindId::new();
                // The handler's type; its first parameter type becomes the type of `x`.
                // NOTE(review): if this bails, `id` was already ref'd above and is
                // never unref'd.
                let mftyp = match &typ.args[4].typ {
                    Type::Fn(ft) => ft.clone(),
                    t => bail!("expected a function not {t}"),
                };
                let (x, xn) = genn::bind(
                    ctx,
                    &scope.lexical,
                    "x",
                    mftyp.args[0].typ.clone(),
                    top_id,
                );
                // handler := (value bound to pid)(x)
                let fnode = genn::reference(ctx, pid, Type::Fn(mftyp.clone()), top_id);
                let handler = genn::apply(fnode, scope, vec![xn], &mftyp, top_id);
                Ok(Box::new(HttpServe {
                    args: CachedVals::new(from),
                    id,
                    top_id,
                    handler,
                    pid,
                    x,
                    queue: VecDeque::new(),
                    ready: true,
                    abort: None,
                }))
            }
            _ => bail!("expected five arguments"),
        }
    }
}
742
impl<R: Rt, E: UserEvent> Apply<R, E> for HttpServe<R, E> {
    /// One evaluation cycle. In order:
    /// 1. refresh the cached arguments and publish a changed handler function;
    /// 2. if any of arguments 0–3 changed, stop the running server and start a
    ///    new one. The new `ServerValue` is this call's output. A configuration
    ///    error returns an `HTTPError` value immediately, skipping steps 3–4
    ///    for this cycle;
    /// 3. queue the incoming request event tagged with `self.id`, if any;
    /// 4. feed queued requests to the handler one at a time, and send each
    ///    handler result back to the request at the head of the queue.
    fn update(
        &mut self,
        ctx: &mut ExecCtx<R, E>,
        from: &mut [Node<R, E>],
        event: &mut Event<E>,
    ) -> Option<Value> {
        let mut changed = [false; 5];
        self.args.update_diff(&mut changed, ctx, from, event);
        // update handler function reference
        if changed[4] {
            if let Some(v) = self.args.0[4].clone() {
                ctx.cached.insert(self.pid, v.clone());
                event.variables.insert(self.pid, v);
            }
        }
        // start/restart server when addr/cert/key/max_connections changes
        let mut server_result = None;
        if changed[0] || changed[1] || changed[2] || changed[3] {
            // The old accept loop is stopped before the new configuration is
            // validated, so an invalid new configuration leaves no server running.
            if let Some(abort) = self.abort.take() {
                abort.abort();
            }
            if let Some(Value::String(addr)) = &self.args.0[0] {
                // build TLS acceptor if cert and key are provided
                // (both bytes -> TLS; both absent/null -> plain; anything else is an error)
                let tls = match (&self.args.0[1], &self.args.0[2]) {
                    (Some(Value::Bytes(cert)), Some(Value::Bytes(key))) => {
                        match build_tls_acceptor(cert, key) {
                            Ok(a) => Some(a),
                            Err(e) => return Some(e),
                        }
                    }
                    (Some(Value::Null), Some(Value::Null))
                    | (None, None)
                    | (Some(Value::Null), None)
                    | (None, Some(Value::Null)) => None,
                    _ => {
                        return Some(errf!(
                            "HTTPError",
                            "both cert and key must be provided for TLS"
                        ))
                    }
                };
                // Non-positive values are rejected; any non-I64 value
                // (including null or a missing value) falls back to 768.
                let max_conn = match &self.args.0[3] {
                    Some(Value::I64(n)) if *n > 0 => *n as usize,
                    Some(Value::I64(n)) => {
                        return Some(errf!(
                            "HTTPError",
                            "max_connections must be > 0, got {n}"
                        ))
                    }
                    _ => 768,
                };
                // Bind synchronously so bind errors surface now, and so the
                // actual bound address (e.g. an OS-chosen port) is known.
                let std_listener = match std::net::TcpListener::bind(&**addr) {
                    Ok(l) => l,
                    Err(e) => {
                        return Some(errf!("HTTPError", "bind to {addr} failed: {e}"))
                    }
                };
                let bound_addr = match std_listener.local_addr() {
                    Ok(a) => a,
                    Err(e) => return Some(errf!("HTTPError", "local_addr failed: {e}")),
                };
                // tokio's `from_std` expects the socket to already be non-blocking.
                if let Err(e) = std_listener.set_nonblocking(true) {
                    return Some(errf!("HTTPError", "set_nonblocking failed: {e}"));
                }
                let listener = match tokio::net::TcpListener::from_std(std_listener) {
                    Ok(l) => l,
                    Err(e) => {
                        return Some(errf!("HTTPError", "tokio listener failed: {e}"))
                    }
                };
                // Bounded (100) channel for request events. NOTE(review):
                // presumably the runtime delivers each batch received on `rx`
                // back here as `event.custom` entries keyed by `id`; confirm.
                let (tx, rx) = mpsc::channel(100);
                ctx.rt.watch(rx);
                let id = self.id;
                let semaphore = Arc::new(tokio::sync::Semaphore::new(max_conn));
                let handle = tokio::spawn(serve_loop(listener, tls, tx, id, semaphore));
                // One abort handle is kept here and another inside the returned
                // ServerHandle, whose Drop also aborts. The server therefore
                // stops on restart/sleep/delete of this node, OR when the last
                // clone of the returned ServerValue is dropped.
                let abort = handle.abort_handle();
                self.abort = Some(abort.clone());
                server_result = Some(SERVER_WRAPPER.wrap(ServerValue {
                    handle: Arc::new(ServerHandle { abort, addr: bound_addr }),
                }));
            }
        }
        // receive incoming requests from the server
        // (only the single entry stored under `self.id` is taken per cycle)
        if let Some(mut cbt) = event.custom.remove(&self.id) {
            if let Some(req) = (&mut *cbt as &mut dyn Any).downcast_mut::<HttpReqEvent>()
            {
                let request = req.request.clone();
                let reply = req.reply.take();
                self.queue.push_back((request, reply));
            }
        }
        // set up first queued request for handler processing
        if self.ready && !self.queue.is_empty() {
            if let Some((req, _)) = self.queue.front() {
                self.ready = false;
                ctx.cached.insert(self.x, req.clone());
                event.variables.insert(self.x, req.clone());
            }
        }
        // process handler responses
        // Each handler output answers the head of the queue (FIFO). The next
        // queued request, if any, is then written to `x` and the handler is
        // polled again within this same cycle. A failed `send` (the
        // connection gave up waiting) is ignored.
        // NOTE(review): an output produced while the queue is empty (e.g. when
        // the handler re-evaluates after argument 4 changed) is discarded.
        loop {
            match self.handler.update(ctx, event) {
                None => break,
                Some(v) => {
                    self.ready = true;
                    if let Some((_, reply)) = self.queue.pop_front() {
                        if let Some(reply) = reply {
                            let _ = reply.send(v);
                        }
                    }
                    match self.queue.front() {
                        Some((req, _)) => {
                            self.ready = false;
                            ctx.cached.insert(self.x, req.clone());
                            event.variables.insert(self.x, req.clone());
                        }
                        None => break,
                    }
                }
            }
        }
        server_result
    }

    /// Type-checks the generated handler call.
    fn typecheck(
        &mut self,
        ctx: &mut ExecCtx<R, E>,
        _from: &mut [Node<R, E>],
        _phase: graphix_compiler::TypecheckPhase<'_>,
    ) -> Result<()> {
        self.handler.typecheck(ctx)?;
        Ok(())
    }

    /// Reports the references held by the generated handler call.
    fn refs(&self, refs: &mut graphix_compiler::Refs) {
        self.handler.refs(refs)
    }

    /// Tear down: release the runtime reference on `id`, stop the accept
    /// loop, and remove the `x`/`pid` bindings and the handler subtree.
    /// Queued reply senders are presumably dropped along with `self`, after
    /// which the waiting connections answer 500.
    fn delete(&mut self, ctx: &mut ExecCtx<R, E>) {
        ctx.rt.unref_var(self.id, self.top_id);
        if let Some(abort) = self.abort.take() {
            abort.abort();
        }
        ctx.cached.remove(&self.x);
        ctx.env.unbind_variable(self.x);
        ctx.cached.remove(&self.pid);
        self.handler.delete(ctx);
    }

    /// Suspend:
    /// - switch to a fresh event id, so events still coming from the stopped
    ///   server are not collected later;
    /// - stop the accept loop;
    /// - clear the cached arguments (presumably so the next update sees them
    ///   as changed and restarts the server);
    /// - drop queued requests. Their reply senders go with them, so the
    ///   waiting connections answer 500.
    fn sleep(&mut self, ctx: &mut ExecCtx<R, E>) {
        ctx.rt.unref_var(self.id, self.top_id);
        self.id = BindId::new();
        ctx.rt.ref_var(self.id, self.top_id);
        if let Some(abort) = self.abort.take() {
            abort.abort();
        }
        self.args.clear();
        self.queue.clear();
        self.ready = true;
        self.handler.sleep(ctx);
    }
}
906
// Package declaration listing the builtins this crate provides. `HttpServe`
// is the only builtin generic over the runtime and event types, so it is
// instantiated here with `GXRt<X>` and `X::UserEvent`.
graphix_derive::defpackage! {
    builtins => [
        HttpClient,
        HttpDefaultClient,
        HttpServerAddr,
        HttpRequest,
        HttpRequestBin,
        HttpServe as HttpServe<GXRt<X>, X::UserEvent>,
    ],
}