graphgate_handler/
handler.rs

1use std::convert::{Infallible, TryInto};
2use std::net::SocketAddr;
3use std::str::FromStr;
4use std::sync::Arc;
5
6use graphgate_planner::Request;
7use http::header::HeaderName;
8use http::HeaderMap;
9use opentelemetry::trace::{FutureExt, TraceContextExt, Tracer};
10use opentelemetry::{global, Context};
11use warp::http::Response as HttpResponse;
12use warp::ws::Ws;
13use warp::{Filter, Rejection, Reply};
14
15use crate::constants::*;
16use crate::metrics::METRICS;
17use crate::{websocket, SharedRouteTable};
18use std::time::Instant;
19
20#[derive(Clone)]
21pub struct HandlerConfig {
22    pub shared_route_table: SharedRouteTable,
23    pub forward_headers: Arc<Vec<String>>,
24}
25
26fn do_forward_headers<T: AsRef<str>>(
27    forward_headers: &[T],
28    header_map: &HeaderMap,
29    remote_addr: Option<SocketAddr>,
30) -> HeaderMap {
31    let mut new_header_map = HeaderMap::new();
32    for name in forward_headers {
33        for value in header_map.get_all(name.as_ref()) {
34            if let Ok(name) = HeaderName::from_str(name.as_ref()) {
35                new_header_map.append(name, value.clone());
36            }
37        }
38    }
39    if let Some(remote_addr) = remote_addr {
40        if let Ok(remote_addr) = remote_addr.to_string().try_into() {
41            new_header_map.append(warp::http::header::FORWARDED, remote_addr);
42        }
43    }
44    new_header_map
45}
46
47pub fn graphql_request(
48    config: HandlerConfig,
49) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
50    warp::post()
51        .and(warp::body::json())
52        .and(warp::header::headers_cloned())
53        .and(warp::addr::remote())
54        .and_then({
55            move |request: Request, header_map: HeaderMap, remote_addr: Option<SocketAddr>| {
56                let config = config.clone();
57                async move {
58                    let tracer = global::tracer("graphql");
59
60                    let query = Context::current_with_span(
61                        tracer
62                            .span_builder("query")
63                            .with_attributes(vec![
64                                KEY_QUERY.string(request.query.clone()),
65                                KEY_VARIABLES
66                                    .string(serde_json::to_string(&request.variables).unwrap()),
67                            ])
68                            .start(&tracer),
69                    );
70
71                    let start_time = Instant::now();
72                    let resp = config
73                        .shared_route_table
74                        .query(
75                            request,
76                            do_forward_headers(&config.forward_headers, &header_map, remote_addr),
77                        )
78                        .with_context(query)
79                        .await;
80
81                    METRICS
82                        .query_histogram
83                        .record((Instant::now() - start_time).as_secs_f64());
84                    METRICS.query_counter.add(1);
85
86                    Ok::<_, Infallible>(resp)
87                }
88            }
89        })
90}
91
92pub fn graphql_websocket(
93    config: HandlerConfig,
94) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
95    warp::ws()
96        .and(warp::get())
97        .and(warp::header::exact_ignore_case("upgrade", "websocket"))
98        .and(warp::header::optional::<String>("sec-websocket-protocol"))
99        .and(warp::header::headers_cloned())
100        .and(warp::addr::remote())
101        .map({
102            move |ws: Ws, protocols: Option<String>, header_map, remote_addr: Option<SocketAddr>| {
103                let config = config.clone();
104                let protocol = protocols
105                    .and_then(|protocols| {
106                        protocols
107                            .split(',')
108                            .find_map(|p| websocket::Protocols::from_str(p.trim()).ok())
109                    })
110                    .unwrap_or(websocket::Protocols::SubscriptionsTransportWS);
111                let header_map =
112                    do_forward_headers(&config.forward_headers, &header_map, remote_addr);
113
114                let reply = ws.on_upgrade(move |websocket| async move {
115                    if let Some((composed_schema, route_table)) =
116                        config.shared_route_table.get().await
117                    {
118                        websocket::server(
119                            composed_schema,
120                            route_table,
121                            websocket,
122                            protocol,
123                            header_map,
124                        )
125                        .await;
126                    }
127                });
128
129                warp::reply::with_header(
130                    reply,
131                    "Sec-WebSocket-Protocol",
132                    protocol.sec_websocket_protocol(),
133                )
134            }
135        })
136}
137
138pub fn graphql_playground() -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
139    warp::get().map(|| {
140        HttpResponse::builder()
141            .header("content-type", "text/html")
142            .body(include_str!("playground.html"))
143    })
144}