apollo_gateway_rs/
lib.rs

1#[forbid(clippy::unwrap_used)]
2#[forbid(clippy::panicking_unwrap)]
3#[forbid(clippy::unnecessary_unwrap)]
4#[forbid(clippy::unwrap_in_result)]
5mod datasource;
6mod handler;
7mod planner;
8mod schema;
9mod validation;
10
11use std::cell::Cell;
12use std::collections::HashMap;
13use std::fs::File;
14use std::io::BufReader;
15use std::marker::PhantomData;
16use std::sync::Arc;
17use serde::Deserialize;
18pub use crate::datasource::{RemoteGraphQLDataSource, Context, GraphqlSourceMiddleware, DefaultSource};
19use crate::datasource::{Config, GraphqlSource, SimpleSource, Source};
20pub use crate::planner::{Response, Request};
21use crate::handler::{ServiceRouteTable, SharedRouteTable};
22
23#[derive(Default)]
24pub struct GatewayServerBuilder {
25    table: HashMap<String, Arc<dyn GraphqlSource>>,
26    limit: Option<usize>,
27    // Compile time check, because someone can don't use build() and push Data<GatewayServerBuilder> instead of Data<GatewayServer> to state of app
28    _marker: PhantomData<Cell<()>>,
29}
30
31impl GatewayServerBuilder {
32    pub fn with_limit_recursive_depth(mut self, limit: usize) -> GatewayServerBuilder {
33        self.limit = Some(limit);
34        self
35    }
36    /// Append sources. Make sure that all sources have unique name
37    pub fn with_sources<S: RemoteGraphQLDataSource>(mut self, sources: impl Iterator<Item=S>) -> GatewayServerBuilder {
38        let sources = sources
39            .map(|source| (source.name().to_string(), Arc::new(SimpleSource { source }) as Arc<dyn GraphqlSource>))
40            .collect::<HashMap<String, Arc<dyn GraphqlSource>>>();
41        self.table.extend(sources);
42        self
43    }
44    /// Append sources with middleware extension. Make sure that all sources have unique name
45    pub fn with_middleware_sources<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware>(mut self, sources: impl Iterator<Item=S>) -> GatewayServerBuilder {
46        let sources = sources
47            .map(|source| (source.name().to_string(), Arc::new(Source { source }) as Arc<dyn GraphqlSource>))
48            .collect::<HashMap<String, Arc<dyn GraphqlSource>>>();
49        self.table.extend(sources);
50        self
51    }
52    /// Append source. Make sure that all sources have unique name
53    pub fn with_source<S: RemoteGraphQLDataSource>(mut self, source: S) -> GatewayServerBuilder {
54        let name = source.name().to_owned();
55        let source = Arc::new(SimpleSource { source });
56        self.table.insert(name, source);
57        self
58    }
59    /// Append source with middleware extension. Make sure that all sources have unique name
60    pub fn with_middleware_source<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware>(mut self, source: S) -> GatewayServerBuilder {
61        let name = source.name().to_owned();
62        let source: Arc<dyn GraphqlSource> = Arc::new(Source { source });
63        self.table.insert(name, source);
64        self
65    }
66    fn from_json<S>(path: &str) -> anyhow::Result<Config<S>> where for<'de> S: Deserialize<'de> {
67        let file = File::open(path)?;
68        let reader = BufReader::new(file);
69        let config = serde_json::from_reader::<_, Config<S>>(reader)?;
70        Ok(config)
71    }
72    /// Append sources from json config for example
73    /// ```json
74    /// {
75    ///     "sources": [
76    ///         {
77    ///             name: "your-source-name",
78    ///             address: "your-source-address",
79    ///         }
80    ///     ]
81    ///
82    /// }
83    /// ```
84    /// Make sure that all sources have unique name
85    pub fn with_sources_from_json<S: RemoteGraphQLDataSource>(mut self, path: &str) -> anyhow::Result<GatewayServerBuilder> where for<'de> S: Deserialize<'de> {
86        let config = Self::from_json::<S>(path)?;
87        let sources = config.simple_sources();
88        self.table.extend(sources);
89        Ok(self)
90    }
91    /// Append sources with middleware extension from json config for example
92    /// ```json
93    /// {
94    ///     "sources": [
95    ///         {
96    ///             name: "your-source-name",
97    ///             address: "your-source-address",
98    ///         }
99    ///     ]
100    ///
101    /// }
102    /// ```
103    /// Make sure that all sources have unique name
104    pub fn with_middleware_sources_from_json<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware>(mut self, path: &str) -> anyhow::Result<GatewayServerBuilder> where for<'de> S: Deserialize<'de> {
105        let config = Self::from_json::<S>(path)?;
106        let sources = config.sources();
107        self.table.extend(sources);
108        Ok(self)
109    }
110
111    /// Build a Gateway-Server. After building gateway-server will try to parse a schema from your remote sources.
112    pub fn build(self) -> GatewayServer {
113        let table = ServiceRouteTable::from(self.table);
114        let shared_route_table = SharedRouteTable::default();
115        shared_route_table.set_route_table(table);
116        GatewayServer {
117            table: shared_route_table,
118            limit: self.limit
119        }
120    }
121}
122
123/// Gateway-server will parse a schema from your remote sources, fetch request and make subscription. Don't forget to pass it into app_data. See example:
124/// ```rust
125/// async fn main() -> std::io::Result<()> {
126///     use actix_web::{App, HttpServer, web::Data};
127///     use apollo_gateway_rs::GatewayServer;
128///     let gateway_server = GatewayServer::builder()
129///         .with_source(CommonSource::new("countries", "countries.trevorblades.com", true))
130///         .build();
131///     let gateway_server = Data::new(gateway_server);
132///     HttpServer::new(move || App::new()
133///         .app_data(gateway_server.clone())
134///         .configure(configure_api)
135///     )
136///         .bind("0.0.0.0:3000")?
137///         .run()
138///         .await
139/// }
140/// ```
141pub struct GatewayServer {
142    table: SharedRouteTable<Arc<dyn GraphqlSource>>,
143    limit: Option<usize>
144}
145
146impl GatewayServer {
147    /// Create a builder for server
148    pub fn builder() -> GatewayServerBuilder {
149        GatewayServerBuilder::default()
150    }
151}
152
153pub mod actix {
154    use std::str::FromStr;
155    use std::sync::Arc;
156    use actix_web::http::header::SEC_WEBSOCKET_PROTOCOL;
157    use actix_web::HttpResponse;
158    use k8s_openapi::serde_json;
159    use opentelemetry::trace::{FutureExt, TraceContextExt, Tracer};
160    use crate::{Context, GatewayServer};
161    use crate::handler::constants::{KEY_QUERY, KEY_VARIABLES};
162    use crate::handler::{Protocols, Subscription};
163    use crate::planner::RequestData;
164
165    /// Request handler
166    pub async fn graphql_request(
167        server: actix_web::web::Data<GatewayServer>,
168        request: actix_web::web::Json<RequestData>,
169        req: actix_web::HttpRequest,
170    ) -> HttpResponse {
171        let request = request.into_inner();
172        let ctx = Context::new(req);
173        let tracer = opentelemetry::global::tracer("graphql");
174        let query = opentelemetry::Context::current_with_span(
175            tracer
176                .span_builder("query")
177                .with_attributes(vec![
178                    KEY_QUERY.string(request.query.clone()),
179                    KEY_VARIABLES.string(serde_json::to_string(&request.variables).unwrap()),
180                ])
181                .start(&tracer),
182        );
183        server.table.query(request, ctx, server.limit).with_context(query).await
184    }
185
186    /// Subscription handler
187    pub async fn graphql_subscription(
188        server: actix_web::web::Data<GatewayServer>,
189        req: actix_web::HttpRequest,
190        payload: actix_web::web::Payload,
191    ) -> HttpResponse {
192        let ctx = Arc::new(Context::new(req.clone()));
193        let protocols = req.headers().get(SEC_WEBSOCKET_PROTOCOL).and_then(|header| header.to_str().ok());
194        let protocol = protocols
195            .and_then(|protocols| {
196                protocols.split(',').find_map(|p| Protocols::from_str(p.trim()).ok())
197            })
198            .unwrap_or(Protocols::SubscriptionsTransportWS);
199        if let Some((composed_schema, route_table)) = server.table.get().await {
200            let protocols = [protocol.sec_websocket_protocol()];
201            let subscription = Subscription::new(composed_schema, route_table, ctx, protocol);
202            return match actix_web_actors::ws::WsResponseBuilder::new(subscription, &req, payload)
203                .protocols(&protocols)
204                .start() {
205                Ok(r) => r,
206                Err(e) => HttpResponse::InternalServerError().body(e.to_string())
207            };
208        }
209        HttpResponse::InternalServerError().finish()
210    }
211}
212
213
214