1#![allow(missing_docs)] mod error;
3mod event;
4
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Context;
8use std::task::Poll;
9
10pub use error::ApolloRouterError;
11pub use event::ConfigurationSource;
12pub(crate) use event::Event;
13pub use event::LicenseSource;
14pub(crate) use event::ReloadSource;
15pub use event::SchemaSource;
16pub use event::ShutdownSource;
17use futures::FutureExt;
18#[cfg(test)]
19use futures::channel::mpsc;
20#[cfg(test)]
21use futures::channel::mpsc::SendError;
22use futures::channel::oneshot;
23use futures::prelude::*;
24#[cfg(test)]
25use tokio::sync::Notify;
26use tokio::sync::RwLock;
27use tokio::task::spawn;
28use tracing_futures::WithSubscriber;
29
30use crate::axum_factory::AxumHttpServerFactory;
31use crate::configuration::ListenAddr;
32use crate::orbiter::OrbiterRouterSuperServiceFactory;
33use crate::router_factory::YamlRouterFactory;
34use crate::state_machine::ListenAddresses;
35use crate::state_machine::StateMachine;
36use crate::uplink::UplinkConfig;
/// Handle to a running router HTTP server.
///
/// Built with `RouterHttpServer::builder()` (generated by `buildstructor`). The handle is
/// itself a [`Future`] that resolves with the router's final result once its state
/// machine stops. `shutdown` requests a stop and waits for that result, and dropping the
/// handle also sends the shutdown request.
pub struct RouterHttpServer {
    /// Result of the spawned state-machine task. A failure to join the task is logged and
    /// mapped to `ApolloRouterError::StartupError`.
    result: Pin<Box<dyn Future<Output = Result<(), ApolloRouterError>> + Send>>,
    /// Listen addresses shared with the state machine, read by the listen-address accessors.
    listen_addresses: Arc<RwLock<ListenAddresses>>,
    /// One-shot channel whose receiving end is part of the router's event stream and
    /// turns into `Event::Shutdown`. It is `None` once shutdown has been requested, and
    /// also for handles built without one (as in the test harness).
    shutdown_sender: Option<oneshot::Sender<()>>,
}
78
79#[buildstructor::buildstructor]
80impl RouterHttpServer {
81 #[builder(visibility = "pub", entry = "builder", exit = "start")]
124 fn start(
125 schema: SchemaSource,
126 configuration: Option<ConfigurationSource>,
127 license: Option<LicenseSource>,
128 shutdown: Option<ShutdownSource>,
129 uplink: Option<UplinkConfig>,
130 is_telemetry_disabled: Option<bool>,
131 ) -> RouterHttpServer {
132 let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();
133 let event_stream = generate_event_stream(
134 shutdown.unwrap_or(ShutdownSource::CtrlC),
135 configuration.unwrap_or_default(),
136 schema,
137 uplink,
138 license.unwrap_or_default(),
139 shutdown_receiver,
140 );
141 let server_factory = AxumHttpServerFactory::new();
142 let router_factory = OrbiterRouterSuperServiceFactory::new(YamlRouterFactory);
143 let state_machine = StateMachine::new(
144 is_telemetry_disabled.unwrap_or(false),
145 server_factory,
146 router_factory,
147 );
148 let listen_addresses = state_machine.listen_addresses.clone();
149 let result = spawn(
150 async move { state_machine.process_events(event_stream).await }
151 .with_current_subscriber(),
152 )
153 .map(|r| match r {
154 Ok(Ok(ok)) => Ok(ok),
155 Ok(Err(err)) => Err(err),
156 Err(err) => {
157 tracing::error!("{}", err);
158 Err(ApolloRouterError::StartupError)
159 }
160 })
161 .with_current_subscriber()
162 .boxed();
163
164 RouterHttpServer {
165 result,
166 shutdown_sender: Some(shutdown_sender),
167 listen_addresses,
168 }
169 }
170
171 pub async fn listen_address(&self) -> Option<ListenAddr> {
178 self.listen_addresses
179 .read()
180 .await
181 .graphql_listen_address
182 .clone()
183 }
184
185 pub async fn extra_listen_adresses(&self) -> Vec<ListenAddr> {
191 self.listen_addresses
192 .read()
193 .await
194 .extra_listen_addresses
195 .clone()
196 }
197
198 pub async fn shutdown(&mut self) -> Result<(), ApolloRouterError> {
200 if let Some(sender) = self.shutdown_sender.take() {
201 let _ = sender.send(());
202 }
203 (&mut self.result).await
204 }
205}
206
207impl Drop for RouterHttpServer {
208 fn drop(&mut self) {
209 if let Some(sender) = self.shutdown_sender.take() {
210 let _ = sender.send(());
211 }
212 }
213}
214
215impl Future for RouterHttpServer {
216 type Output = Result<(), ApolloRouterError>;
217
218 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 self.result.poll_unpin(cx)
220 }
221}
222
223fn generate_event_stream(
227 shutdown: ShutdownSource,
228 configuration: ConfigurationSource,
229 schema: SchemaSource,
230 uplink_config: Option<UplinkConfig>,
231 license: LicenseSource,
232 shutdown_receiver: oneshot::Receiver<()>,
233) -> impl Stream<Item = Event> {
234 let reload_source = ReloadSource::default();
235
236 stream::select_all(vec![
237 shutdown.into_stream().boxed(),
238 schema.into_stream().boxed(),
239 license.into_stream().boxed(),
240 reload_source.clone().into_stream().boxed(),
241 configuration
242 .into_stream(uplink_config)
243 .map(move |config_event| {
244 if let Event::UpdateConfiguration(config) = &config_event {
245 reload_source.set_period(&config.experimental_chaos.force_reload)
246 }
247 config_event
248 })
249 .boxed(),
250 shutdown_receiver
251 .into_stream()
252 .map(|_| Event::Shutdown)
253 .boxed(),
254 ])
255 .take_while(|msg| future::ready(!matches!(msg, Event::Shutdown)))
256 .chain(stream::iter(vec![Event::Shutdown]))
258 .boxed()
259}
260
#[cfg(test)]
/// Test harness around a [`RouterHttpServer`] whose state machine is fed events directly
/// by the test, instead of by real configuration, schema or license sources.
struct TestRouterHttpServer {
    /// Handle to the spawned state-machine task. It is built without a shutdown sender,
    /// so dropping it sends nothing.
    router_http_server: RouterHttpServer,
    /// Sending half of the unbounded channel the state machine reads its events from.
    event_sender: mpsc::UnboundedSender<Event>,
    /// Shared with the state machine via `StateMachine::for_tests` and awaited after each
    /// sent event. Presumably it is signalled once per processed event — TODO confirm in
    /// state_machine.
    state_machine_update_notifier: Arc<Notify>,
}
267
#[cfg(test)]
impl TestRouterHttpServer {
    /// Spawns a router whose state machine consumes events from a test-owned channel.
    ///
    /// The state machine comes from `StateMachine::for_tests` and reads the receiving half
    /// of an unbounded channel. Tests push events with `send_event`. No shutdown sender is
    /// installed, so shutdown must be requested by sending `Event::Shutdown`.
    fn new() -> Self {
        let (event_sender, event_receiver) = mpsc::unbounded();
        let state_machine_update_notifier = Arc::new(Notify::new());

        let state_machine = StateMachine::for_tests(
            AxumHttpServerFactory::new(),
            OrbiterRouterSuperServiceFactory::new(YamlRouterFactory),
            Arc::clone(&state_machine_update_notifier),
        );

        let listen_addresses = state_machine.listen_addresses.clone();
        let result = spawn(
            async move { state_machine.process_events(event_receiver).await }
                .with_current_subscriber(),
        )
        .map(|joined| {
            // The outer `Err` is a task join failure (panic or cancellation).
            joined.unwrap_or_else(|err| {
                tracing::error!("{}", err);
                Err(ApolloRouterError::StartupError)
            })
        })
        .with_current_subscriber()
        .boxed();

        TestRouterHttpServer {
            router_http_server: RouterHttpServer {
                result,
                shutdown_sender: None,
                listen_addresses,
            },
            event_sender,
            state_machine_update_notifier,
        }
    }

    /// POSTs `request` as JSON to the router's GraphQL address and decodes the response.
    ///
    /// A missing listen address, a transport error or an undecodable body all panic,
    /// so the returned `Result` is always `Ok`.
    async fn request(
        &self,
        request: crate::graphql::Request,
    ) -> Result<crate::graphql::Response, crate::error::FetchError> {
        Ok(reqwest::Client::new()
            .post(format!("{}/", self.listen_address().await.unwrap()))
            .json(&request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    /// Returns the GraphQL listen address of the wrapped router, if any.
    async fn listen_address(&self) -> Option<ListenAddr> {
        self.router_http_server.listen_address().await
    }

    /// Sends `event` to the state machine and waits for its update notification.
    ///
    /// The `Notified` future is created before the event is sent. Tokio guarantees that a
    /// `Notified` receives `notify_waiters` wakeups from the moment it is created, so a
    /// notification fired before this task polls is not lost. If the send fails because
    /// the state machine has stopped and dropped the receiver, the error is returned
    /// immediately. The old code waited for a notification that may never arrive.
    async fn send_event(&mut self, event: Event) -> Result<(), SendError> {
        let processed = self.state_machine_update_notifier.notified();
        self.event_sender.send(event).await?;
        processed.await;
        Ok(())
    }

    /// Sends `Event::Shutdown` and waits for the state-machine task's final result.
    async fn shutdown(mut self) -> Result<(), ApolloRouterError> {
        self.send_event(Event::Shutdown).await.unwrap();
        self.router_http_server.shutdown().await
    }
}
340
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use serde_json::to_string_pretty;

    use super::*;
    use crate::Configuration;
    use crate::graphql;
    use crate::graphql::Request;
    use crate::router::Event::UpdateConfiguration;
    use crate::router::Event::UpdateLicense;
    use crate::router::Event::UpdateSchema;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::schema::SchemaState;

    /// Starts a router through the public builder, using the `testdata` supergraph
    /// configuration and schema.
    fn init_with_server() -> RouterHttpServer {
        let configuration =
            Configuration::from_str(include_str!("../testdata/supergraph_config.router.yaml"))
                .unwrap();
        let schema = include_str!("../testdata/supergraph.graphql");
        RouterHttpServer::builder()
            .configuration(configuration)
            .schema(schema)
            .start()
    }

    /// End-to-end smoke test of the builder API: start, send one query, shut down.
    #[tokio::test(flavor = "multi_thread")]
    async fn basic_request() {
        let mut router_handle = init_with_server();
        let listen_address = router_handle
            .listen_address()
            .await
            .expect("router failed to start");

        assert_federated_response(&listen_address, r#"{ topProducts { name } }"#).await;
        router_handle.shutdown().await.unwrap();
    }

    /// Sends `request` as a GraphQL query to the router at `listen_addr`.
    ///
    /// Passing means the HTTP call succeeded and the body deserialized into a
    /// `graphql::Response`. The response content itself is not checked.
    async fn assert_federated_response(listen_addr: &ListenAddr, request: &str) {
        let request = Request::builder().query(request).build();
        // NOTE(review): despite its name, `expected` holds the router's actual response.
        let expected = query(listen_addr, &request).await.unwrap();

        // Serializing a response to JSON presumably never yields an empty string, so this
        // mostly re-checks that the query above succeeded.
        let response = to_string_pretty(&expected).unwrap();
        assert!(!response.is_empty());
    }

    /// POSTs `request` as JSON to `{listen_addr}/` and decodes the GraphQL response.
    ///
    /// Transport and decoding failures panic via `expect`, so the `Result` is always `Ok`.
    async fn query(
        listen_addr: &ListenAddr,
        request: &graphql::Request,
    ) -> Result<graphql::Response, crate::error::FetchError> {
        Ok(reqwest::Client::new()
            .post(format!("{listen_addr}/"))
            .json(request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    /// Drives the state machine by hand: configuration, schema and license bring the
    /// router up, one query is checked, then the sources are ended and shutdown is sent.
    #[tokio::test(flavor = "multi_thread")]
    async fn basic_event_stream_test() {
        let mut router_handle = TestRouterHttpServer::new();

        let configuration =
            Configuration::from_str(include_str!("../testdata/supergraph_config.router.yaml"))
                .unwrap();
        let schema = include_str!("../testdata/supergraph.graphql");

        // Feed the three inputs the state machine needs before it can serve requests.
        router_handle
            .send_event(UpdateConfiguration(Arc::new(configuration)))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: schema.to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(LicenseState::Unlicensed))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();

        let response = router_handle.request(request).await.unwrap();
        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        // Signal that the configuration and schema sources are finished, then stop.
        router_handle
            .send_event(Event::NoMoreConfiguration)
            .await
            .unwrap();
        router_handle.send_event(Event::NoMoreSchema).await.unwrap();
        router_handle.send_event(Event::Shutdown).await.unwrap();
    }

    /// Hot-reloads the schema back and forth and checks that query validation follows it.
    /// `supergraph_missing_name.graphql` lacks `User.name`; `supergraph.graphql` has it.
    #[tokio::test(flavor = "multi_thread")]
    async fn schema_update_test() {
        let mut router_handle = TestRouterHttpServer::new();
        // Start with the schema that does NOT expose `User.name`.
        router_handle
            .send_event(UpdateConfiguration(Arc::new(
                Configuration::from_str(include_str!("../testdata/supergraph_config.router.yaml"))
                    .unwrap(),
            )))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: include_str!("../testdata/supergraph_missing_name.graphql").to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(LicenseState::Unlicensed))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        // `name` is not in this schema, so validation must reject the query.
        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            r#"Cannot query field "name" on type "User"."#, response.errors[0].message,
            "{response:?}"
        );
        assert_eq!(
            "GRAPHQL_VALIDATION_FAILED",
            response.errors[0].extensions.get("code").unwrap()
        );

        // Reload the full schema; the same query must now succeed.
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: include_str!("../testdata/supergraph.graphql").to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();

        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "Ada Lovelace",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("name")
                .unwrap()
        );

        // Revert to the schema without `name`; validation must fail again.
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: include_str!("../testdata/supergraph_missing_name.graphql").to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            r#"Cannot query field "name" on type "User"."#,
            response.errors[0].message,
        );
        assert_eq!(
            "GRAPHQL_VALIDATION_FAILED",
            response.errors[0].extensions.get("code").unwrap()
        );
        router_handle.shutdown().await.unwrap();
    }
}