apollo_gateway_rs/datasource/
mod.rs

1#![allow(clippy::obfuscated_if_else)]
2
3use std::collections::HashMap;
4use std::ops::Deref;
5use std::pin::Pin;
6use std::sync::Arc;
7use actix::dev::Stream;
8use actix_web::HttpRequest;
9use futures_util::TryFutureExt;
10use http::HeaderMap;
11use once_cell::sync::Lazy;
12use crate::planner::{Response};
13
/// Represents a connection between your federated gateway and one of your subgraphs.
///
/// Only [`name`](Self::name) and [`address`](Self::address) must be implemented; every
/// other method has a default.
pub trait RemoteGraphQLDataSource: Sync + Send + 'static {
    /// Name of this source. If you have multiple sources, each must have a unique name.
    fn name(&self) -> &str;

    /// Host (and optional port) of the subgraph, e.g. `countries.trevorblades.com`.
    /// Do not include the `http(s)://` scheme; it is chosen from [`tls`](Self::tls).
    fn address(&self) -> &str;

    /// Whether to use `https`/`wss` instead of `http`/`ws`. Defaults to `false`.
    fn tls(&self) -> bool {
        false
    }

    /// Path appended to the address for queries and mutations. Defaults to `None` (root path).
    fn query_path(&self) -> Option<&str> {
        None
    }

    /// Path appended to the address for subscriptions. Defaults to `None` (root path).
    fn subscribe_path(&self) -> Option<&str> {
        None
    }

    /// Full URL for queries and mutations, e.g. `https://example.com/graphql`.
    ///
    /// A trailing `/` on the address and leading `/`s on the path are ignored, so
    /// `"/graphql"` and `"graphql"` both yield a single separating slash.
    fn url_query(&self) -> String {
        let scheme = if self.tls() { "https" } else { "http" };
        let address = self.address().trim_end_matches('/');
        // Avoid `host//graphql` when the configured path already starts with a slash.
        let path = self.query_path().unwrap_or("").trim_start_matches('/');
        format!("{scheme}://{address}/{path}")
    }

    /// Full URL for subscriptions, e.g. `wss://example.com/ws`.
    ///
    /// Slash handling matches [`url_query`](Self::url_query).
    fn url_subscription(&self) -> String {
        let scheme = if self.tls() { "wss" } else { "ws" };
        let address = self.address().trim_end_matches('/');
        // Avoid `host//ws` when the configured path already starts with a slash.
        let path = self.subscribe_path().unwrap_or("").trim_start_matches('/');
        format!("{scheme}://{address}/{path}")
    }
}
36
37use serde::Deserialize;
38use serde_json::Value;
39use crate::Request;
40
41#[derive(Deserialize)]
42pub struct Config<S> {
43    sources: Vec<S>,
44}
45
46/// If you want to load your sources from config you can use DefaultSource. If you not provide tls in your config default value would be false
47#[derive(Deserialize)]
48pub struct DefaultSource {
49    name: String,
50    address: String,
51    #[serde(default = "bool::default")]
52    tls: bool,
53    query_path: Option<String>,
54    subscribe_path: Option<String>,
55}
56
57impl RemoteGraphQLDataSource for DefaultSource {
58    fn name(&self) -> &str {
59        &self.name
60    }
61    fn address(&self) -> &str {
62        &self.address
63    }
64    fn tls(&self) -> bool {
65        self.tls
66    }
67    fn query_path(&self) -> Option<&str> {
68        self.query_path.as_deref()
69    }
70    fn subscribe_path(&self) -> Option<&str> {
71        self.subscribe_path.as_deref()
72    }
73}
74
75impl<S: RemoteGraphQLDataSource> Config<S> {
76    pub fn simple_sources(self) -> HashMap<String, Arc<dyn GraphqlSource>> {
77        self.sources.into_iter()
78            .map(|source| (source.name().to_string(), Arc::new(SimpleSource { source }) as Arc<dyn GraphqlSource>))
79            .collect::<HashMap<String, Arc<dyn GraphqlSource>>>()
80    }
81}
82
83impl<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware> Config<S> {
84    pub fn sources(self) -> HashMap<String, Arc<dyn GraphqlSource>> {
85        self.sources.into_iter()
86            .map(|source| (source.name().to_string(), Arc::new(Source { source }) as Arc<dyn GraphqlSource>))
87            .collect::<HashMap<String, Arc<dyn GraphqlSource>>>()
88    }
89}
90
91
92static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(Default::default);
93
94type SubscriptionStream = Pin<Box<dyn Stream<Item = anyhow::Result<Response>>>>;
95/// Implement GraphqlSourceMiddleware for your source, if you want to modify requests to the subgraph before they're sent and modify response after it.
96#[async_trait::async_trait]
97pub trait GraphqlSourceMiddleware: Send + Sync + 'static + RemoteGraphQLDataSource {
98    /// Override will_send_request to modify your gateway's requests to the subgraph before they're sent.
99    #[allow(unused_variables)]
100    async fn will_send_request(&self, request: &mut HashMap<String, String>, ctx: &Context) -> anyhow::Result<()> {
101        Ok(())
102    }
103    /// Override did_receive_response to modify your gateway's response after request to the subgraph. It will not modify response of subscription.
104    #[allow(unused_variables)]
105    async fn did_receive_response(&self, response: &mut Response, ctx: &Context) -> anyhow::Result<()> {
106        Ok(())
107    }
108
109    #[allow(unused_variables)]
110    async fn on_connection_init(&self, message: &mut Option<Value>, ctx: &Context) -> anyhow::Result<()> {
111        Ok(())
112    }
113
114    async fn fetch(&self, request: Request) -> anyhow::Result<Response> {
115        let url = self.url_query();
116        let headers = HeaderMap::try_from(&request.headers)?;
117        let raw_resp = HTTP_CLIENT
118            .post(&url)
119            .headers(headers)
120            .json(&request.data)
121            .send()
122            .and_then(|res| async move { res.error_for_status() })
123            .await?;
124        let headers = raw_resp.headers().iter()
125            .filter_map(|(name, value)| value.to_str().ok().map(|value| (name.as_str().to_string(), value.to_string())))
126            .collect();
127        let mut resp = raw_resp.json::<Response>().await?;
128        if !resp.errors.is_empty() {
129
130        }
131        resp.headers = headers;
132        Ok(resp)
133    }
134    async fn subscribe(&self, _request: Request) -> SubscriptionStream {
135        unimplemented!()
136    }
137}
138
139impl RemoteGraphQLDataSource for Arc<dyn GraphqlSource> {
140    #[inline]
141    fn name(&self) -> &str {
142        self.deref().name()
143    }
144    #[inline]
145    fn address(&self) -> &str {
146        self.deref().address()
147    }
148    #[inline]
149    fn tls(&self) -> bool {
150        self.deref().tls()
151    }
152    #[inline]
153    fn query_path(&self) -> Option<&str> {
154        self.deref().query_path()
155    }
156    #[inline]
157    fn subscribe_path(&self) -> Option<&str> {
158        self.deref().subscribe_path()
159    }
160    #[inline]
161    fn url_query(&self) -> String {
162        self.deref().url_query()
163    }
164    #[inline]
165    fn url_subscription(&self) -> String {
166        self.deref().url_subscription()
167    }
168}
169
170#[async_trait::async_trait]
171impl GraphqlSourceMiddleware for Arc<dyn GraphqlSource> {
172    async fn will_send_request(&self, request: &mut HashMap<String, String>, ctx: &Context) -> anyhow::Result<()> {
173        self.deref().will_send_request(request, ctx).await
174    }
175    async fn did_receive_response(&self, response: &mut Response, ctx: &Context) -> anyhow::Result<()> {
176        self.deref().did_receive_response(response, ctx).await
177    }
178    async fn on_connection_init(&self, message: &mut Option<Value>, ctx: &Context) -> anyhow::Result<()> {
179        self.deref().on_connection_init(message, ctx).await
180    }
181    async fn fetch(&self, request: Request) -> anyhow::Result<Response> {
182        self.deref().fetch(request).await
183    }
184    async fn subscribe(&self, request: Request) -> SubscriptionStream {
185        self.deref().subscribe(request).await
186    }
187}
188
189impl GraphqlSource for Arc<dyn GraphqlSource> {}
190
191pub trait GraphqlSource: RemoteGraphQLDataSource + GraphqlSourceMiddleware {}
192
193
194pub struct SimpleSource<S: RemoteGraphQLDataSource> {
195    pub(crate) source: S,
196}
197
198impl<S: RemoteGraphQLDataSource> GraphqlSourceMiddleware for SimpleSource<S> {}
199
200impl<S: RemoteGraphQLDataSource> RemoteGraphQLDataSource for SimpleSource<S> {
201    #[inline]
202    fn name(&self) -> &str {
203        self.source.name()
204    }
205    #[inline]
206    fn address(&self) -> &str {
207        self.source.address()
208    }
209    #[inline]
210    fn tls(&self) -> bool {
211        self.source.tls()
212    }
213    #[inline]
214    fn query_path(&self) -> Option<&str> {
215        self.source.query_path()
216    }
217    #[inline]
218    fn subscribe_path(&self) -> Option<&str> {
219        self.source.subscribe_path()
220    }
221    #[inline]
222    fn url_query(&self) -> String {
223        self.source.url_query()
224    }
225    #[inline]
226    fn url_subscription(&self) -> String {
227        self.source.url_subscription()
228    }
229}
230
231impl<S: RemoteGraphQLDataSource> GraphqlSource for SimpleSource<S> {}
232
233pub struct Source<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware> {
234    pub(crate) source: S,
235}
236
237impl<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware> RemoteGraphQLDataSource for Source<S> {
238    #[inline]
239    fn name(&self) -> &str {
240        self.source.name()
241    }
242    #[inline]
243    fn address(&self) -> &str {
244        self.source.address()
245    }
246    #[inline]
247    fn tls(&self) -> bool {
248        self.source.tls()
249    }
250    #[inline]
251    fn query_path(&self) -> Option<&str> {
252        self.source.query_path()
253    }
254    #[inline]
255    fn subscribe_path(&self) -> Option<&str> {
256        self.source.subscribe_path()
257    }
258    #[inline]
259    fn url_query(&self) -> String {
260        self.source.url_query()
261    }
262    #[inline]
263    fn url_subscription(&self) -> String {
264        self.source.url_subscription()
265    }
266}
267
268impl<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware> GraphqlSource for Source<S> {}
269
270#[async_trait::async_trait]
271impl<S: RemoteGraphQLDataSource + GraphqlSourceMiddleware> GraphqlSourceMiddleware for Source<S> {
272    async fn will_send_request(&self, request: &mut HashMap<String, String>, ctx: &Context) -> anyhow::Result<()> {
273        self.source.will_send_request(request, ctx).await
274    }
275    async fn on_connection_init(&self, message: &mut Option<Value>, ctx: &Context) -> anyhow::Result<()> {
276        self.source.on_connection_init(message, ctx).await
277    }
278    async fn did_receive_response(&self, response: &mut Response, ctx: &Context) -> anyhow::Result<()> {
279        self.source.did_receive_response(response, ctx).await
280    }
281    async fn fetch(&self, request: Request) -> anyhow::Result<Response> {
282        self.source.fetch(request).await
283    }
284    async fn subscribe(&self, request: Request) -> SubscriptionStream {
285        self.source.subscribe(request).await
286    }
287}
288/// Context give you access to request data like headers, app_data and extensions.
289pub struct Context(HttpRequest);
290
291impl Context {
292    pub fn new(request: HttpRequest) -> Self {
293        Self(request)
294    }
295}
296
297impl Deref for Context {
298    type Target = HttpRequest;
299    fn deref(&self) -> &Self::Target {
300        &self.0
301    }
302}
303
304unsafe impl Send for Context {}
305
306unsafe impl Sync for Context {}
307