1#![allow(missing_docs)] #![allow(deprecated)] mod error;
5mod event;
6
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::Context;
10use std::task::Poll;
11
12pub use error::ApolloRouterError;
13pub use event::ConfigurationSource;
14pub(crate) use event::Event;
15pub use event::LicenseSource;
16pub use event::SchemaSource;
17pub use event::ShutdownSource;
18use futures::FutureExt;
19#[cfg(test)]
20use futures::channel::mpsc;
21#[cfg(test)]
22use futures::channel::mpsc::SendError;
23use futures::channel::oneshot;
24use futures::prelude::*;
25#[cfg(test)]
26use tokio::sync::Notify;
27use tokio::sync::RwLock;
28use tokio::task::spawn;
29use tracing_futures::WithSubscriber;
30
31use crate::axum_factory::AxumHttpServerFactory;
32use crate::configuration::ListenAddr;
33use crate::orbiter::OrbiterRouterSuperServiceFactory;
34use crate::plugins::chaos::ChaosEventStream;
35use crate::router::event::reload::ReloadableEventStream;
36use crate::router_factory::YamlRouterFactory;
37use crate::state_machine::ListenAddresses;
38use crate::state_machine::StateMachine;
39use crate::uplink::UplinkConfig;
/// Handle to a router whose state machine runs on a spawned tokio task.
///
/// Created through the builder whose exit function is `start`. The handle
/// implements [`Future`] and resolves when the background task finishes. It can
/// report the router's listen addresses and request a shutdown. Dropping the
/// handle only signals a shutdown (see the `Drop` impl); it does not wait for
/// the task to exit.
pub struct RouterHttpServer {
    /// Completion of the spawned state machine task. It yields the state
    /// machine's own result, or `ApolloRouterError::StartupError` when the task
    /// itself failed (e.g. panicked).
    result: Pin<Box<dyn Future<Output = Result<(), ApolloRouterError>> + Send>>,
    /// Listen addresses shared with the `StateMachine` (cloned from its
    /// `listen_addresses`); presumably updated by it as servers are bound.
    listen_addresses: Arc<RwLock<ListenAddresses>>,
    /// Sending half of the programmatic shutdown channel. It becomes `None`
    /// once used, and is always `None` for handles built by the test harness.
    shutdown_sender: Option<oneshot::Sender<()>>,
}
81
82#[buildstructor::buildstructor]
83impl RouterHttpServer {
84 #[builder(visibility = "pub", entry = "builder", exit = "start")]
127 fn start(
128 schema: SchemaSource,
129 configuration: Option<ConfigurationSource>,
130 license: Option<LicenseSource>,
131 shutdown: Option<ShutdownSource>,
132 uplink: Option<UplinkConfig>,
133 is_telemetry_disabled: Option<bool>,
134 ) -> RouterHttpServer {
135 let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();
136 let event_stream = generate_event_stream(
137 shutdown.unwrap_or(ShutdownSource::CtrlC),
138 configuration.unwrap_or_default(),
139 schema,
140 uplink,
141 license.unwrap_or_default(),
142 shutdown_receiver,
143 );
144 let server_factory = AxumHttpServerFactory::new();
145 let router_factory = OrbiterRouterSuperServiceFactory::new(YamlRouterFactory);
146 let state_machine = StateMachine::new(
147 is_telemetry_disabled.unwrap_or(false),
148 server_factory,
149 router_factory,
150 );
151 let listen_addresses = state_machine.listen_addresses.clone();
152 let result = spawn(
153 async move { state_machine.process_events(event_stream).await }
154 .with_current_subscriber(),
155 )
156 .map(|r| match r {
157 Ok(Ok(ok)) => Ok(ok),
158 Ok(Err(err)) => Err(err),
159 Err(err) => {
160 tracing::error!("{}", err);
161 Err(ApolloRouterError::StartupError)
162 }
163 })
164 .with_current_subscriber()
165 .boxed();
166
167 RouterHttpServer {
168 result,
169 shutdown_sender: Some(shutdown_sender),
170 listen_addresses,
171 }
172 }
173
174 pub async fn listen_address(&self) -> Option<ListenAddr> {
181 self.listen_addresses
182 .read()
183 .await
184 .graphql_listen_address
185 .clone()
186 }
187
188 pub async fn extra_listen_adresses(&self) -> Vec<ListenAddr> {
194 self.listen_addresses
195 .read()
196 .await
197 .extra_listen_addresses
198 .clone()
199 }
200
201 pub async fn shutdown(&mut self) -> Result<(), ApolloRouterError> {
203 if let Some(sender) = self.shutdown_sender.take() {
204 let _ = sender.send(());
205 }
206 (&mut self.result).await
207 }
208}
209
210impl Drop for RouterHttpServer {
211 fn drop(&mut self) {
212 if let Some(sender) = self.shutdown_sender.take() {
213 let _ = sender.send(());
214 }
215 }
216}
217
218impl Future for RouterHttpServer {
219 type Output = Result<(), ApolloRouterError>;
220
221 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222 self.result.poll_unpin(cx)
223 }
224}
225
226fn generate_event_stream(
230 shutdown: ShutdownSource,
231 configuration: ConfigurationSource,
232 schema: SchemaSource,
233 uplink_config: Option<UplinkConfig>,
234 license: LicenseSource,
235 shutdown_receiver: oneshot::Receiver<()>,
236) -> impl Stream<Item = Event> {
237 stream::select_all(vec![
238 shutdown.into_stream().boxed(),
239 schema.into_stream().boxed(),
240 license.into_stream().boxed(),
241 configuration.into_stream(uplink_config).boxed(),
242 shutdown_receiver
243 .into_stream()
244 .map(|_| Event::Shutdown)
245 .boxed(),
246 ])
247 .with_sighup_reload()
248 .with_chaos_reload()
249 .take_while(|msg| future::ready(!matches!(msg, Event::Shutdown)))
250 .chain(stream::iter(vec![Event::Shutdown]))
251 .boxed()
252}
253
#[cfg(test)]
/// Test harness around `RouterHttpServer` whose state machine is fed from an
/// in-memory channel instead of the real configuration/schema/license sources.
struct TestRouterHttpServer {
    /// Handle to the spawned state machine task; built without a shutdown sender.
    router_http_server: RouterHttpServer,
    /// Sends events straight into the state machine's event stream.
    event_sender: mpsc::UnboundedSender<Event>,
    /// Shared with the state machine via `StateMachine::for_tests`; awaited
    /// after each event sent by the harness.
    state_machine_update_notifier: Arc<Notify>,
}
260
#[cfg(test)]
impl TestRouterHttpServer {
    /// Spawns a state machine whose only input is an in-memory event channel.
    ///
    /// Nothing is configured up front. Tests feed configuration, schema and
    /// license events through `send_event`.
    fn new() -> Self {
        let (event_sender, event_receiver) = mpsc::unbounded();
        let state_machine_update_notifier = Arc::new(Notify::new());

        let server_factory = AxumHttpServerFactory::new();
        let router_factory: OrbiterRouterSuperServiceFactory =
            OrbiterRouterSuperServiceFactory::new(YamlRouterFactory);
        let state_machine = StateMachine::for_tests(
            server_factory,
            router_factory,
            Arc::clone(&state_machine_update_notifier),
        );

        let listen_addresses = state_machine.listen_addresses.clone();
        let result = spawn(
            async move { state_machine.process_events(event_receiver).await }
                .with_current_subscriber(),
        )
        .map(|joined| match joined {
            // The task ran to completion: forward the state machine's own result.
            Ok(outcome) => outcome,
            // The task itself failed (e.g. it panicked or was cancelled).
            Err(join_error) => {
                tracing::error!("{}", join_error);
                Err(ApolloRouterError::StartupError)
            }
        })
        .with_current_subscriber()
        .boxed();

        TestRouterHttpServer {
            router_http_server: RouterHttpServer {
                result,
                // No programmatic shutdown channel: tests send `Event::Shutdown` directly.
                shutdown_sender: None,
                listen_addresses,
            },
            event_sender,
            state_machine_update_notifier,
        }
    }

    /// Posts `request` to the router's GraphQL endpoint and decodes the reply.
    ///
    /// Panics if no listen address is recorded, if the HTTP request cannot be
    /// sent, or if the body does not deserialize into a GraphQL response.
    async fn request(
        &self,
        request: crate::graphql::Request,
    ) -> Result<crate::graphql::Response, crate::error::FetchError> {
        let address = self
            .listen_address()
            .await
            .expect("router has no GraphQL listen address");
        Ok(reqwest::Client::new()
            .post(format!("{}/", address))
            .json(&request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    /// Returns the GraphQL listen address recorded by the state machine, if any.
    async fn listen_address(&self) -> Option<ListenAddr> {
        self.router_http_server.listen_address().await
    }

    /// Sends `event` to the state machine, then waits for the next signal on
    /// the shared `Notify`.
    ///
    /// The `Notified` future is created *before* sending. A `notify_waiters()`
    /// call that lands while the send is in flight is then still observed,
    /// whereas `notify_one()` stores a permit either way. If the send fails
    /// because the receiver is gone, the error is returned immediately instead
    /// of waiting for a notification that can never arrive.
    async fn send_event(&mut self, event: Event) -> Result<(), SendError> {
        let handled = self.state_machine_update_notifier.notified();
        self.event_sender.send(event).await?;
        handled.await;
        Ok(())
    }

    /// Sends `Event::Shutdown`, waits for the state machine to acknowledge it,
    /// then waits for the background task to finish.
    async fn shutdown(mut self) -> Result<(), ApolloRouterError> {
        self.send_event(Event::Shutdown)
            .await
            .expect("state machine stopped before receiving the shutdown event");
        self.router_http_server.shutdown().await
    }
}
333
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use serde_json::to_string_pretty;

    use super::*;
    use crate::Configuration;
    use crate::graphql;
    use crate::graphql::Request;
    use crate::router::Event::UpdateConfiguration;
    use crate::router::Event::UpdateLicense;
    use crate::router::Event::UpdateSchema;
    use crate::uplink::license_enforcement::LicenseState;
    use crate::uplink::schema::SchemaState;

    /// Parses the router configuration shared by every test in this module.
    fn test_configuration() -> Configuration {
        Configuration::from_str(include_str!("../testdata/supergraph_config.router.yaml"))
            .unwrap()
    }

    /// Builds an `UpdateSchema` event carrying `sdl` and no launch id.
    fn schema_update(sdl: &str) -> Event {
        UpdateSchema(SchemaState {
            sdl: sdl.to_string(),
            launch_id: None,
        })
    }

    /// Starts a router through the public builder with the shared test
    /// configuration and supergraph.
    fn init_with_server() -> RouterHttpServer {
        RouterHttpServer::builder()
            .configuration(test_configuration())
            .schema(include_str!("../testdata/supergraph.graphql"))
            .start()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_request() {
        let mut router_handle = init_with_server();
        let listen_address = router_handle
            .listen_address()
            .await
            .expect("router failed to start");

        assert_federated_response(&listen_address, r#"{ topProducts { name } }"#).await;
        router_handle.shutdown().await.unwrap();
    }

    /// Sends `request` to the router and checks only that a GraphQL response
    /// comes back and serializes to non-empty JSON.
    async fn assert_federated_response(listen_addr: &ListenAddr, request: &str) {
        let request = Request::builder().query(request).build();
        let response = query(listen_addr, &request).await.unwrap();

        let serialized = to_string_pretty(&response).unwrap();
        assert!(!serialized.is_empty());
    }

    /// Posts `request` to `listen_addr` and decodes the GraphQL response,
    /// panicking on transport or decoding failures.
    async fn query(
        listen_addr: &ListenAddr,
        request: &graphql::Request,
    ) -> Result<graphql::Response, crate::error::FetchError> {
        Ok(reqwest::Client::new()
            .post(format!("{listen_addr}/"))
            .json(request)
            .send()
            .await
            .expect("couldn't send request")
            .json()
            .await
            .expect("couldn't deserialize into json"))
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_event_stream_test() {
        let mut router_handle = TestRouterHttpServer::new();

        router_handle
            .send_event(UpdateConfiguration(test_configuration().into()))
            .await
            .unwrap();
        router_handle
            .send_event(schema_update(include_str!("../testdata/supergraph.graphql")))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(LicenseState::Unlicensed))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();

        let response = router_handle.request(request).await.unwrap();
        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        router_handle
            .send_event(Event::NoMoreConfiguration)
            .await
            .unwrap();
        router_handle.send_event(Event::NoMoreSchema).await.unwrap();
        router_handle.send_event(Event::Shutdown).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn schema_update_test() {
        let mut router_handle = TestRouterHttpServer::new();
        // Start on a supergraph where `User` has no `name` field.
        router_handle
            .send_event(UpdateConfiguration(Arc::new(test_configuration())))
            .await
            .unwrap();
        router_handle
            .send_event(schema_update(include_str!(
                "../testdata/supergraph_missing_name.graphql"
            )))
            .await
            .unwrap();
        router_handle
            .send_event(UpdateLicense(LicenseState::Unlicensed))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            r#"Cannot query field "name" on type "User"."#,
            response.errors[0].message,
            "{response:?}"
        );
        assert_eq!(
            "GRAPHQL_VALIDATION_FAILED",
            response.errors[0].extensions.get("code").unwrap()
        );

        // Switch to the full supergraph: `name` becomes queryable.
        router_handle
            .send_event(schema_update(include_str!("../testdata/supergraph.graphql")))
            .await
            .unwrap();

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();

        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "Ada Lovelace",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("name")
                .unwrap()
        );

        // Switch back: `name` must be rejected again.
        router_handle
            .send_event(schema_update(include_str!(
                "../testdata/supergraph_missing_name.graphql"
            )))
            .await
            .unwrap();

        let request = Request::builder().query(r#"{ me { username } }"#).build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            "@ada",
            response
                .data
                .unwrap()
                .get("me")
                .unwrap()
                .get("username")
                .unwrap()
        );

        let request = Request::builder()
            .query(r#"{ me { username name } }"#)
            .build();
        let response = router_handle.request(request).await.unwrap();

        assert_eq!(
            r#"Cannot query field "name" on type "User"."#,
            response.errors[0].message,
            "{response:?}"
        );
        assert_eq!(
            "GRAPHQL_VALIDATION_FAILED",
            response.errors[0].extensions.get("code").unwrap()
        );
        router_handle.shutdown().await.unwrap();
    }
}