1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
use std::convert::{Infallible, TryInto}; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; use graphgate_planner::Request; use http::header::HeaderName; use http::HeaderMap; use opentelemetry::trace::{FutureExt, TraceContextExt, Tracer}; use opentelemetry::{global, Context}; use warp::http::Response as HttpResponse; use warp::ws::Ws; use warp::{Filter, Rejection, Reply}; use crate::constants::*; use crate::metrics::METRICS; use crate::{websocket, SharedRouteTable}; use std::time::Instant; #[derive(Clone)] pub struct HandlerConfig { pub shared_route_table: SharedRouteTable, pub forward_headers: Arc<Vec<String>>, } fn do_forward_headers<T: AsRef<str>>( forward_headers: &[T], header_map: &HeaderMap, remote_addr: Option<SocketAddr>, ) -> HeaderMap { let mut new_header_map = HeaderMap::new(); for name in forward_headers { for value in header_map.get_all(name.as_ref()) { if let Ok(name) = HeaderName::from_str(name.as_ref()) { new_header_map.append(name, value.clone()); } } } if let Some(remote_addr) = remote_addr { if let Ok(remote_addr) = remote_addr.to_string().try_into() { new_header_map.append(warp::http::header::FORWARDED, remote_addr); } } new_header_map } pub fn graphql_request( config: HandlerConfig, ) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone { warp::post() .and(warp::body::json()) .and(warp::header::headers_cloned()) .and(warp::addr::remote()) .and_then({ move |request: Request, header_map: HeaderMap, remote_addr: Option<SocketAddr>| { let config = config.clone(); async move { let tracer = global::tracer("graphql"); let query = Context::current_with_span( tracer .span_builder("query") .with_attributes(vec![ KEY_QUERY.string(request.query.clone()), KEY_VARIABLES .string(serde_json::to_string(&request.variables).unwrap()), ]) .start(&tracer), ); let start_time = Instant::now(); let resp = config .shared_route_table .query( request, do_forward_headers(&config.forward_headers, &header_map, remote_addr), ) .with_context(query) .await; METRICS .query_histogram .record((Instant::now() - start_time).as_secs_f64()); METRICS.query_counter.add(1); Ok::<_, Infallible>(resp) } } }) } pub fn graphql_websocket( config: HandlerConfig, ) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone { warp::ws() .and(warp::get()) .and(warp::header::exact_ignore_case("upgrade", "websocket")) .and(warp::header::optional::<String>("sec-websocket-protocol")) .and(warp::header::headers_cloned()) .and(warp::addr::remote()) .map({ move |ws: Ws, protocols: Option<String>, header_map, remote_addr: Option<SocketAddr>| { let config = config.clone(); let protocol = protocols .and_then(|protocols| { protocols .split(',') .find_map(|p| websocket::Protocols::from_str(p.trim()).ok()) }) .unwrap_or(websocket::Protocols::SubscriptionsTransportWS); let header_map = do_forward_headers(&config.forward_headers, &header_map, remote_addr); let reply = ws.on_upgrade(move |websocket| async move { if let Some((composed_schema, route_table)) = config.shared_route_table.get().await { websocket::server( composed_schema, route_table, websocket, protocol, header_map, ) .await; } }); warp::reply::with_header( reply, "Sec-WebSocket-Protocol", protocol.sec_websocket_protocol(), ) } }) } pub fn graphql_playground() -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone { warp::get().map(|| { HttpResponse::builder() .header("content-type", "text/html") .body(include_str!("playground.html")) }) }