1#![allow(missing_docs)] mod error;
3mod event;
4
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Context;
8use std::task::Poll;
9
10pub use error::ApolloRouterError;
11pub use event::ConfigurationSource;
12pub(crate) use event::Event;
13pub use event::LicenseSource;
14pub use event::SchemaSource;
15pub use event::ShutdownSource;
16use futures::FutureExt;
17#[cfg(test)]
18use futures::channel::mpsc;
19#[cfg(test)]
20use futures::channel::mpsc::SendError;
21use futures::channel::oneshot;
22use futures::prelude::*;
23#[cfg(test)]
24use tokio::sync::Notify;
25use tokio::sync::RwLock;
26use tokio::task::spawn;
27use tracing_futures::WithSubscriber;
28
29use crate::axum_factory::AxumHttpServerFactory;
30use crate::configuration::ListenAddr;
31use crate::orbiter::OrbiterRouterSuperServiceFactory;
32use crate::plugins::chaos::ChaosEventStream;
33use crate::router::event::reload::ReloadableEventStream;
34use crate::router_factory::YamlRouterFactory;
35use crate::state_machine::ListenAddresses;
36use crate::state_machine::StateMachine;
37use crate::uplink::UplinkConfig;
/// Handle to a running router HTTP server.
///
/// Created with `RouterHttpServer::builder()`. The handle is itself a [`Future`] that
/// resolves with the final result of the router's state machine. Calling
/// `RouterHttpServer::shutdown` or dropping the handle sends the shutdown signal.
// NOTE(review): fields drop in declaration order; keep the order unless that is intended.
pub struct RouterHttpServer {
    // Completion of the spawned state-machine task, mapped to the router's error type.
    result: Pin<Box<dyn Future<Output = Result<(), ApolloRouterError>> + Send>>,
    // Listen-address table shared with the state machine (cloned from `StateMachine::listen_addresses`).
    listen_addresses: Arc<RwLock<ListenAddresses>>,
    // One-shot used to request shutdown; `None` once the signal was sent (or for test handles).
    shutdown_sender: Option<oneshot::Sender<()>>,
}
79
80#[buildstructor::buildstructor]
81impl RouterHttpServer {
82 #[builder(visibility = "pub", entry = "builder", exit = "start")]
125 fn start(
126 schema: SchemaSource,
127 configuration: Option<ConfigurationSource>,
128 license: Option<LicenseSource>,
129 shutdown: Option<ShutdownSource>,
130 uplink: Option<UplinkConfig>,
131 is_telemetry_disabled: Option<bool>,
132 ) -> RouterHttpServer {
133 let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();
134 let event_stream = generate_event_stream(
135 shutdown.unwrap_or(ShutdownSource::CtrlC),
136 configuration.unwrap_or_default(),
137 schema,
138 uplink,
139 license.unwrap_or_default(),
140 shutdown_receiver,
141 );
142 let server_factory = AxumHttpServerFactory::new();
143 let router_factory = OrbiterRouterSuperServiceFactory::new(YamlRouterFactory);
144 let state_machine = StateMachine::new(
145 is_telemetry_disabled.unwrap_or(false),
146 server_factory,
147 router_factory,
148 );
149 let listen_addresses = state_machine.listen_addresses.clone();
150 let result = spawn(
151 async move { state_machine.process_events(event_stream).await }
152 .with_current_subscriber(),
153 )
154 .map(|r| match r {
155 Ok(Ok(ok)) => Ok(ok),
156 Ok(Err(err)) => Err(err),
157 Err(err) => {
158 tracing::error!("{}", err);
159 Err(ApolloRouterError::StartupError)
160 }
161 })
162 .with_current_subscriber()
163 .boxed();
164
165 RouterHttpServer {
166 result,
167 shutdown_sender: Some(shutdown_sender),
168 listen_addresses,
169 }
170 }
171
172 pub async fn listen_address(&self) -> Option<ListenAddr> {
179 self.listen_addresses
180 .read()
181 .await
182 .graphql_listen_address
183 .clone()
184 }
185
186 pub async fn extra_listen_adresses(&self) -> Vec<ListenAddr> {
192 self.listen_addresses
193 .read()
194 .await
195 .extra_listen_addresses
196 .clone()
197 }
198
199 pub async fn shutdown(&mut self) -> Result<(), ApolloRouterError> {
201 if let Some(sender) = self.shutdown_sender.take() {
202 let _ = sender.send(());
203 }
204 (&mut self.result).await
205 }
206}
207
208impl Drop for RouterHttpServer {
209 fn drop(&mut self) {
210 if let Some(sender) = self.shutdown_sender.take() {
211 let _ = sender.send(());
212 }
213 }
214}
215
216impl Future for RouterHttpServer {
217 type Output = Result<(), ApolloRouterError>;
218
219 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
220 self.result.poll_unpin(cx)
221 }
222}
223
224fn generate_event_stream(
228 shutdown: ShutdownSource,
229 configuration: ConfigurationSource,
230 schema: SchemaSource,
231 uplink_config: Option<UplinkConfig>,
232 license: LicenseSource,
233 shutdown_receiver: oneshot::Receiver<()>,
234) -> impl Stream<Item = Event> {
235 stream::select_all(vec![
236 shutdown.into_stream().boxed(),
237 schema.into_stream().boxed(),
238 license.into_stream().boxed(),
239 configuration.into_stream(uplink_config).boxed(),
240 shutdown_receiver
241 .into_stream()
242 .map(|_| Event::Shutdown)
243 .boxed(),
244 ])
245 .with_sighup_reload()
246 .with_chaos_reload()
247 .take_while(|msg| future::ready(!matches!(msg, Event::Shutdown)))
248 .chain(stream::iter(vec![Event::Shutdown]))
249 .boxed()
250}
251
/// Test-only variant of `RouterHttpServer` whose state machine is fed events directly
/// through an unbounded channel instead of the real event sources.
#[cfg(test)]
struct TestRouterHttpServer {
    // Handle to the spawned state machine; its `shutdown_sender` is `None`.
    router_http_server: RouterHttpServer,
    // Feeds events straight into `StateMachine::process_events`.
    event_sender: mpsc::UnboundedSender<Event>,
    // Handed to `StateMachine::for_tests`; presumably notified after each processed event
    // (`send_event` waits on it) — confirm in the state machine.
    state_machine_update_notifier: Arc<Notify>,
}
258
#[cfg(test)]
impl TestRouterHttpServer {
    /// Spawns a state machine that reads its events from an in-memory channel.
    ///
    /// Tests drive it by pushing events (configuration, schema, license, ...) through
    /// `send_event`.
    fn new() -> Self {
        let (event_sender, event_receiver) = mpsc::unbounded();
        let state_machine_update_notifier = Arc::new(Notify::new());

        let server_factory = AxumHttpServerFactory::new();
        let router_factory = OrbiterRouterSuperServiceFactory::new(YamlRouterFactory);
        let state_machine = StateMachine::for_tests(
            server_factory,
            router_factory,
            Arc::clone(&state_machine_update_notifier),
        );

        let listen_addresses = state_machine.listen_addresses.clone();
        let result = spawn(
            async move { state_machine.process_events(event_receiver).await }
                .with_current_subscriber(),
        )
        .map(|joined| {
            // A join error means the task panicked or was cancelled.
            joined.unwrap_or_else(|err| {
                tracing::error!("{}", err);
                Err(ApolloRouterError::StartupError)
            })
        })
        .with_current_subscriber()
        .boxed();

        TestRouterHttpServer {
            router_http_server: RouterHttpServer {
                result,
                // Shutdown is requested by sending `Event::Shutdown` on the channel instead.
                shutdown_sender: None,
                listen_addresses,
            },
            event_sender,
            state_machine_update_notifier,
        }
    }

    /// Posts `request` as JSON to the router's GraphQL endpoint and decodes the response.
    ///
    /// # Panics
    /// If no listen address is recorded, or if sending or decoding the request fails.
    async fn request(
        &self,
        request: crate::graphql::Request,
    ) -> Result<crate::graphql::Response, crate::error::FetchError> {
        let listen_address = self
            .listen_address()
            .await
            .expect("router has no listen address");
        Ok(reqwest::Client::new()
            .post(format!("{listen_address}/"))
            .json(&request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    /// Returns the GraphQL listen address of the wrapped `RouterHttpServer`.
    async fn listen_address(&self) -> Option<ListenAddr> {
        self.router_http_server.listen_address().await
    }

    /// Sends `event` to the state machine and waits until it signals it has handled it.
    ///
    /// If the channel is closed (the state machine no longer receives events), the send
    /// error is returned immediately. Waiting for a notification in that case could block
    /// forever.
    async fn send_event(&mut self, event: Event) -> Result<(), SendError> {
        // Create the `Notified` future before sending. Tokio guarantees it observes
        // `notify_waiters()` from this point on, and `notify_one()` stores a permit anyway.
        // So the acknowledgement cannot slip in between the send and the wait.
        let handled = self.state_machine_update_notifier.notified();
        self.event_sender.send(event).await?;
        handled.await;
        Ok(())
    }

    /// Sends `Event::Shutdown` and waits for the state machine task to finish.
    async fn shutdown(mut self) -> Result<(), ApolloRouterError> {
        self.send_event(Event::Shutdown)
            .await
            .expect("state machine stopped receiving events before shutdown");
        self.router_http_server.shutdown().await
    }
}
331
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use serde_json::to_string_pretty;

    use super::*;
    use crate::Configuration;
    use crate::graphql;
    use crate::graphql::Request;
    use crate::router::Event::UpdateConfiguration;
    use crate::router::Event::UpdateLicense;
    use crate::router::Event::UpdateSchema;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::schema::SchemaState;

    /// Supergraph SDL shared by the tests below.
    const SUPERGRAPH: &str = include_str!("../testdata/supergraph.graphql");
    /// Variant of `SUPERGRAPH` in which `User.name` cannot be queried.
    const SUPERGRAPH_MISSING_NAME: &str =
        include_str!("../testdata/supergraph_missing_name.graphql");

    /// Parses the router configuration shared by the tests below.
    fn test_configuration() -> Configuration {
        Configuration::from_str(include_str!("../testdata/supergraph_config.router.yaml"))
            .unwrap()
    }

    /// Starts a real router with the test configuration and supergraph.
    /// Shutdown and license sources use the builder defaults.
    fn init_with_server() -> RouterHttpServer {
        RouterHttpServer::builder()
            .configuration(test_configuration())
            .schema(SUPERGRAPH)
            .start()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_request() {
        let mut router_handle = init_with_server();
        let listen_address = router_handle
            .listen_address()
            .await
            .expect("router failed to start");

        assert_federated_response(&listen_address, r#"{ topProducts { name } }"#).await;
        router_handle.shutdown().await.unwrap();
    }

    /// Sends `request` to the router and checks that a GraphQL response came back.
    ///
    /// This is a smoke check only. It asserts that the response deserializes and
    /// re-serializes to non-empty JSON. It does not check that the response carries
    /// data or is free of errors.
    async fn assert_federated_response(listen_addr: &ListenAddr, request: &str) {
        let request = Request::builder().query(request).build();
        let response = query(listen_addr, &request).await.unwrap();

        let serialized = to_string_pretty(&response).unwrap();
        assert!(!serialized.is_empty());
    }

    /// Posts `request` as JSON to `listen_addr` and decodes the GraphQL response.
    ///
    /// Panics if the request cannot be sent or the body is not a GraphQL response.
    async fn query(
        listen_addr: &ListenAddr,
        request: &graphql::Request,
    ) -> Result<graphql::Response, crate::error::FetchError> {
        Ok(reqwest::Client::new()
            .post(format!("{listen_addr}/"))
            .json(request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    /// Asserts that `response` failed validation because `User.name` is not in the schema.
    fn assert_name_field_rejected(response: &graphql::Response) {
        let error = response
            .errors
            .first()
            .unwrap_or_else(|| panic!("expected a validation error, got {response:?}"));
        assert_eq!(
            r#"Cannot query field "name" on type "User"."#, error.message,
            "{response:?}"
        );
        assert_eq!(
            "GRAPHQL_VALIDATION_FAILED",
            error.extensions.get("code").unwrap(),
            "{response:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_event_stream_test() {
        let mut router_handle = TestRouterHttpServer::new();

        // Push configuration, schema and license so the router can start serving.
        router_handle
            .send_event(UpdateConfiguration(Arc::new(test_configuration())))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: SUPERGRAPH.to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(Arc::new(LicenseState::Unlicensed)))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();

        let response = router_handle.request(request).await.unwrap();
        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        // Signal that no further configuration or schema updates will come, then stop.
        router_handle
            .send_event(Event::NoMoreConfiguration)
            .await
            .unwrap();
        router_handle.send_event(Event::NoMoreSchema).await.unwrap();
        router_handle.send_event(Event::Shutdown).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn schema_update_test() {
        let mut router_handle = TestRouterHttpServer::new();
        // Start with a schema in which `User.name` is not queryable.
        router_handle
            .send_event(UpdateConfiguration(Arc::new(test_configuration())))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: SUPERGRAPH_MISSING_NAME.to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(Arc::new(LicenseState::Unlicensed)))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        // `name` is not part of this schema, so validation must reject it.
        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();
        assert_name_field_rejected(&response);

        // Update to the full schema: `name` becomes queryable.
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: SUPERGRAPH.to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();

        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "Ada Lovelace",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("name")
                .unwrap()
        );

        // Switch back: `username` keeps working while `name` is rejected again.
        router_handle
            .send_event(UpdateSchema(SchemaState {
                sdl: SUPERGRAPH_MISSING_NAME.to_string(),
                launch_id: None,
            }))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();
        assert_name_field_rejected(&response);

        router_handle.shutdown().await.unwrap();
    }
}