apollo_router/batching.rs

//! Various utility functions and core structures used to implement batching support within
//! the router.

use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

use opentelemetry::Context as otelContext;
use opentelemetry::trace::TraceContextExt;
use parking_lot::Mutex as PMutex;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tower::BoxError;
use tracing::Instrument;
use tracing::Span;

use crate::Context;
use crate::error::FetchError;
use crate::error::SubgraphBatchingError;
use crate::graphql;
use crate::plugins::telemetry::otel::span_ext::OpenTelemetrySpanExt;
use crate::services::SubgraphRequest;
use crate::services::SubgraphResponse;
use crate::services::http::HttpClientServiceFactory;
use crate::services::process_batches;
use crate::services::router;
use crate::services::router::body::RouterBody;
use crate::services::subgraph::SubgraphRequestId;
use crate::spec::QueryHash;

/// A query that is part of a batch.
/// Note: It's ok to make transient clones of this struct, but *do not* store clones anywhere apart
/// from the single copy in the extensions. The batching co-ordinator relies on the fact that all
/// senders are dropped to know when to finish processing.
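///
/// A sketch of the intended lifecycle, shown as a non-compiled illustration (these types are
/// crate-private, and `client_factory`, `request` and `gql_request` are assumed to be built
/// elsewhere):
///
/// ```ignore
/// let batch = Arc::new(Batch::spawn_handler(1));
/// let bq = Batch::query_for_index(batch, 0)?;
///
/// // Query planning has determined that this batch item needs a single fetch
/// bq.set_query_hashes(vec![Arc::new(QueryHash::default())]).await?;
///
/// // The fetch reports progress; the returned receiver eventually yields the response
/// let rx = bq.signal_progress(client_factory, request, gql_request).await?;
/// assert!(bq.finished());
/// ```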
#[derive(Clone, Debug)]
pub(crate) struct BatchQuery {
    /// The index of this query relative to the entire batch
    index: usize,

    /// A channel sender for sending updates to the entire batch
    sender: Arc<Mutex<Option<mpsc::Sender<BatchHandlerMessage>>>>,

    /// How many more progress updates are we expecting to send?
    remaining: Arc<AtomicUsize>,

    /// Batch to which this BatchQuery belongs
    batch: Arc<Batch>,
}

impl fmt::Display for BatchQuery {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "index: {}, ", self.index)?;
        write!(f, "remaining: {}, ", self.remaining.load(Ordering::Acquire))?;
        write!(f, "sender: {:?}, ", self.sender)?;
        write!(f, "batch: {:?}, ", self.batch)?;
        Ok(())
    }
}

impl BatchQuery {
    /// Is this BatchQuery finished?
    pub(crate) fn finished(&self) -> bool {
        self.remaining.load(Ordering::Acquire) == 0
    }

    /// Inform the batch of query hashes representing fetches needed by this element of the batch query
    pub(crate) async fn set_query_hashes(
        &self,
        query_hashes: Vec<Arc<QueryHash>>,
    ) -> Result<(), BoxError> {
        self.remaining.store(query_hashes.len(), Ordering::Release);

        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Begin {
                index: self.index,
                query_hashes,
            })
            .await?;
        Ok(())
    }

    /// Signal to the batch handler that this specific batch query has made some progress.
    ///
    /// The returned channel can be awaited to receive the GraphQL response, when ready.
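    ///
    /// A minimal sketch of consuming the returned channel (illustrative only; the arguments
    /// are assumed to come from the calling subgraph service, and the timeout value is
    /// arbitrary):
    ///
    /// ```ignore
    /// let rx = batch_query.signal_progress(client_factory, request, gql_request).await?;
    /// let response: SubgraphResponse = tokio::time::timeout(Duration::from_secs(30), rx)
    ///     .await? // Did the wait time out?
    ///     ?       // Was the sender dropped without a response?
    ///     ?;      // Did the fetch itself fail?
    /// ```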
    pub(crate) async fn signal_progress(
        &self,
        client_factory: HttpClientServiceFactory,
        request: SubgraphRequest,
        gql_request: graphql::Request,
    ) -> Result<oneshot::Receiver<Result<SubgraphResponse, BoxError>>, BoxError> {
        // Create a receiver for this query so that it can eventually get the response meant for it
        let (tx, rx) = oneshot::channel();

        tracing::debug!(
            "index: {}, REMAINING: {}",
            self.index,
            self.remaining.load(Ordering::Acquire)
        );
        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Progress(Box::new(
                BatchHandlerMessageProgress {
                    index: self.index,
                    client_factory,
                    request,
                    gql_request,
                    response_sender: tx,
                    span_context: Span::current().context(),
                },
            )))
            .await?;

        if !self.finished() {
            self.remaining.fetch_sub(1, Ordering::AcqRel);
        }

        // May now be finished; if so, drop the sender so the handler knows we are done
        if self.finished() {
            let mut sender = self.sender.lock().await;
            *sender = None;
        }

        Ok(rx)
    }

    /// Signal to the batch handler that this specific batch query is cancelled
    pub(crate) async fn signal_cancelled(&self, reason: String) -> Result<(), BoxError> {
        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Cancel {
                index: self.index,
                reason,
            })
            .await?;

        if !self.finished() {
            self.remaining.fetch_sub(1, Ordering::AcqRel);
        }

        // May now be finished; if so, drop the sender so the handler knows we are done
        if self.finished() {
            let mut sender = self.sender.lock().await;
            *sender = None;
        }

        Ok(())
    }
}

// #[derive(Debug)]
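/// Messages sent from individual [`BatchQuery`] handles to the spawned batching task.
///
/// The expected sequence for each batch item is one `Begin` (emitted after query planning),
/// followed by a `Progress` or `Cancel` for each registered query hash. The task keeps
/// receiving until every sender has been dropped, then assembles and dispatches the batch.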
enum BatchHandlerMessage {
    /// Cancel one of the batch items
    Cancel {
        index: usize,
        reason: String,
    },

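    /// A batch item has reached the subgraph service and carries everything needed to make
    /// its fetch (see [`BatchHandlerMessageProgress`])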
    Progress(Box<BatchHandlerMessageProgress>),

    /// A query has passed query planning and knows how many fetches are needed
    /// to complete.
    Begin {
        index: usize,
        query_hashes: Vec<Arc<QueryHash>>,
    },
}

/// A query has reached the subgraph service and we should update its state
struct BatchHandlerMessageProgress {
    index: usize,
    client_factory: HttpClientServiceFactory,
    request: SubgraphRequest,
    gql_request: graphql::Request,
    response_sender: oneshot::Sender<Result<SubgraphResponse, BoxError>>,
    span_context: otelContext,
}

/// Collection of info needed to resolve a batch query
pub(crate) struct BatchQueryInfo {
    /// The owning subgraph request
    request: SubgraphRequest,

    /// The GraphQL request tied to this subgraph request
    gql_request: graphql::Request,

    /// Notifier for the subgraph service handler
    ///
    /// Note: This must be used or else the subgraph request will time out
    sender: oneshot::Sender<Result<SubgraphResponse, BoxError>>,
}

// TODO: Do we want to generate a UUID for a batch for observability reasons?
// TODO: Do we want to track the size of a batch?
#[derive(Debug)]
pub(crate) struct Batch {
    /// Sender channels used to communicate with the batching handler, one per batch item
    senders: PMutex<Vec<Option<mpsc::Sender<BatchHandlerMessage>>>>,

    /// The spawned batching handler task handle
    ///
    /// Note: We keep this as a failsafe. If the task doesn't terminate _before_ the batch is
    /// dropped, then we will abort() the task on drop.
    spawn_handle: JoinHandle<Result<(), BoxError>>,

    /// What is the size (number of input operations) of the batch?
    #[allow(dead_code)]
    size: usize,
}

impl Batch {
    /// Creates a new batch, spawning an async task for handling updates to the
    /// batch lifecycle.
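    ///
    /// An illustrative (non-compiled) sketch of creating a batch of two and taking its
    /// per-item handles:
    ///
    /// ```ignore
    /// let batch = Arc::new(Batch::spawn_handler(2));
    /// let first = Batch::query_for_index(batch.clone(), 0)?;
    /// let second = Batch::query_for_index(batch.clone(), 1)?;
    /// // Taking the same index twice, or an out-of-bounds index, is an error
    /// ```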
    pub(crate) fn spawn_handler(size: usize) -> Self {
        tracing::debug!("New batch created with size {size}");

        // Create the message channel pair for sending update events to the spawned task
        let (spawn_tx, mut rx) = mpsc::channel(size);

        // Populate Senders
        let mut senders = vec![];

        for _ in 0..size {
            senders.push(Some(spawn_tx.clone()));
        }

        let spawn_handle = tokio::spawn(async move {
            /// Helper struct for keeping track of the state of each individual BatchQuery
            #[derive(Debug)]
            struct BatchQueryState {
                registered: HashSet<Arc<QueryHash>>,
                committed: HashSet<Arc<QueryHash>>,
                cancelled: HashSet<Arc<QueryHash>>,
            }

            impl BatchQueryState {
                // We are ready when everything we registered is in either cancelled or
                // committed.
                fn is_ready(&self) -> bool {
                    self.registered
                        .iter()
                        .all(|hash| self.committed.contains(hash) || self.cancelled.contains(hash))
                }
            }
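            // For example (with hypothetical hashes): registered = {h1, h2}, committed = {h1},
            // cancelled = {h2} is ready, since every registered hash is accounted for; if
            // cancelled were empty, h2 would still be outstanding.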

            // Progressively track the state of the various batch fetches that we expect to see. Keys are batch
            // indices.
            let mut batch_state: HashMap<usize, BatchQueryState> = HashMap::with_capacity(size);

            // We also need to keep track of all requests we need to make and their send handles
            let mut requests: Vec<Vec<BatchQueryInfo>> =
                Vec::from_iter((0..size).map(|_| Vec::new()));

            let mut master_client_factory = None;
            tracing::debug!("Batch about to await messages...");
            // Start handling messages from various portions of the request lifecycle
            // When recv() returns None, we want to stop processing messages
            while let Some(msg) = rx.recv().await {
                match msg {
                    BatchHandlerMessage::Cancel { index, reason } => {
                        // Log the reason for cancelling, update the state
                        tracing::debug!("Cancelling index: {index}, {reason}");

                        if let Some(state) = batch_state.get_mut(&index) {
                            // Short-circuit any requests that are waiting for this cancelled request to complete.
                            let cancelled_requests = std::mem::take(&mut requests[index]);
                            for BatchQueryInfo {
                                request, sender, ..
                            } in cancelled_requests
                            {
                                let subgraph_name = request.subgraph_name;
                                if let Err(log_error) =
                                    sender.send(Err(Box::new(FetchError::SubrequestBatchingError {
                                        service: subgraph_name.clone(),
                                        reason: format!("request cancelled: {reason}"),
                                    })))
                                {
                                    tracing::error!(
                                        service = subgraph_name,
                                        error = ?log_error,
                                        "failed to notify waiter that request is cancelled"
                                    );
                                }
                            }

                            // Clear out everything that has committed, now that they are cancelled, and
                            // mark everything as having been cancelled.
                            state.committed.clear();
                            state.cancelled = state.registered.clone();
                        }
                    }

                    BatchHandlerMessage::Begin {
                        index,
                        query_hashes,
                    } => {
                        tracing::debug!("Beginning batch for index {index} with {query_hashes:?}");

                        batch_state.insert(
                            index,
                            BatchQueryState {
                                cancelled: HashSet::with_capacity(query_hashes.len()),
                                committed: HashSet::with_capacity(query_hashes.len()),
                                registered: HashSet::from_iter(query_hashes),
                            },
                        );
                    }

                    BatchHandlerMessage::Progress(progress) => {
                        // Record progress for this batch index
                        let BatchHandlerMessageProgress {
                            index,
                            client_factory,
                            request,
                            gql_request,
                            response_sender,
                            span_context,
                        } = *progress;

                        tracing::debug!("Progress index: {index}");

                        if let Some(state) = batch_state.get_mut(&index) {
                            state.committed.insert(request.query_hash.clone());
                        }

                        if master_client_factory.is_none() {
                            master_client_factory = Some(client_factory);
                        }
                        Span::current().add_link(span_context.span().span_context().clone());
                        requests[index].push(BatchQueryInfo {
                            request,
                            gql_request,
                            sender: response_sender,
                        })
                    }
                }
            }

            // Make sure that we are actually ready and haven't forgotten to update something somewhere
            if batch_state.values().any(|f| !f.is_ready()) {
                tracing::error!(
                    "All senders for the batch have dropped before reaching the ready state: {batch_state:#?}"
                );
                // There's not much else we can do, so perform an early return
                return Err(SubgraphBatchingError::ProcessingFailed(
                    "batch senders not ready when required".to_string(),
                )
                .into());
            }

            tracing::debug!("Assembling {size} requests into batches");

            // We now have a bunch of requests which are organised by index and we would like to
            // convert them into a bunch of requests organised by service...
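            //
            // For example (hypothetical subgraph names), requests grouped by batch index:
            //     [[q0_products, q0_reviews], [q1_products]]
            // become requests grouped by subgraph:
            //     {"products": [q0_products, q1_products], "reviews": [q0_reviews]}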

            let all_in_one: Vec<_> = requests.into_iter().flatten().collect();

            // Now build up a service-oriented view to use in constructing our batches
            let mut svc_map: HashMap<String, Vec<BatchQueryInfo>> = HashMap::new();
            for BatchQueryInfo {
                request: sg_request,
                gql_request,
                sender: tx,
            } in all_in_one
            {
                let subgraph_name = sg_request.subgraph_name.clone();
                let value = svc_map.entry(subgraph_name).or_default();
                value.push(BatchQueryInfo {
                    request: sg_request,
                    gql_request,
                    sender: tx,
                });
            }

            // If we don't have a master_client_factory, we can't do anything.
            if let Some(client_factory) = master_client_factory {
                process_batches(client_factory, svc_map).await?;
            }
            Ok(())
        }.instrument(tracing::info_span!("batch_request", size)));

        Self {
            senders: PMutex::new(senders),
            spawn_handle,
            size,
        }
    }

    /// Create a batch query for a specific index in this batch
    ///
    /// This function may fail if the index doesn't exist or has already been taken
    pub(crate) fn query_for_index(
        batch: Arc<Batch>,
        index: usize,
    ) -> Result<BatchQuery, SubgraphBatchingError> {
        let mut guard = batch.senders.lock();
        // It's a serious error if we try to get a query at an index which doesn't exist or which has already been taken
        if index >= guard.len() {
            return Err(SubgraphBatchingError::ProcessingFailed(format!(
                "tried to retrieve sender for index: {index} which does not exist"
            )));
        }
        let opt_sender = std::mem::take(&mut guard[index]);
        if opt_sender.is_none() {
            return Err(SubgraphBatchingError::ProcessingFailed(format!(
                "tried to retrieve sender for index: {index} which has already been taken"
            )));
        }
        drop(guard);
        Ok(BatchQuery {
            index,
            sender: Arc::new(Mutex::new(opt_sender)),
            remaining: Arc::new(AtomicUsize::new(0)),
            batch,
        })
    }
}

impl Drop for Batch {
    fn drop(&mut self) {
        // Failsafe: make sure that we kill the background task if the batch itself is dropped
        self.spawn_handle.abort();
    }
}

/// Assemble a single batch request to a subgraph
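///
/// A sketch of the result under assumed inputs (names are illustrative): the body of the
/// returned request serializes all of the GraphQL requests as one JSON array, the operation
/// name is taken from the first request, and contexts and senders are returned in input
/// order so responses can be routed back to their waiters.
///
/// ```ignore
/// let (op_name, contexts, http_request, senders) = assemble_batch(requests).await?;
/// // Body, e.g.: [{"query":"query batch_test { slot0 }"},{"query":"query batch_test { slot1 }"}]
/// ```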
pub(crate) async fn assemble_batch(
    requests: Vec<BatchQueryInfo>,
) -> Result<
    (
        String,
        Vec<(Context, SubgraphRequestId)>,
        http::Request<RouterBody>,
        Vec<oneshot::Sender<Result<SubgraphResponse, BoxError>>>,
    ),
    BoxError,
> {
    // Extract the collection of parts from the requests
    let (txs, request_pairs): (Vec<_>, Vec<_>) = requests
        .into_iter()
        .map(|r| (r.sender, (r.request, r.gql_request)))
        .unzip();
    let (requests, gql_requests): (Vec<_>, Vec<_>) = request_pairs.into_iter().unzip();

    // Construct the actual byte body of the batched request
    let bytes = router::body::into_bytes(serde_json::to_string(&gql_requests)?).await?;

    // Retain the various contexts for later use
    let contexts = requests
        .iter()
        .map(|request| (request.context.clone(), request.id.clone()))
        .collect::<Vec<(Context, SubgraphRequestId)>>();
    // Grab the common info from the first request
    let first_request = requests
        .into_iter()
        .next()
        .ok_or(SubgraphBatchingError::RequestsIsEmpty)?
        .subgraph_request;
    let operation_name = first_request
        .body()
        .operation_name
        .clone()
        .unwrap_or_default();
    let (parts, _) = first_request.into_parts();

    // Generate the final request and pass it up
    let request = http::Request::from_parts(parts, router::body::from_bytes(bytes));
    Ok((operation_name, contexts, request, txs))
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use http::header::ACCEPT;
    use http::header::CONTENT_TYPE;
    use tokio::sync::oneshot;
    use tower::ServiceExt;
    use wiremock::MockServer;
    use wiremock::ResponseTemplate;
    use wiremock::matchers;

    use super::Batch;
    use super::BatchQueryInfo;
    use super::assemble_batch;
    use crate::Configuration;
    use crate::Context;
    use crate::TestHarness;
    use crate::graphql;
    use crate::graphql::Request;
    use crate::layers::ServiceExt as LayerExt;
    use crate::services::SubgraphRequest;
    use crate::services::SubgraphResponse;
    use crate::services::http::HttpClientServiceFactory;
    use crate::services::router;
    use crate::services::router::body;
    use crate::services::subgraph;
    use crate::services::subgraph::SubgraphRequestId;
    use crate::spec::QueryHash;

    #[tokio::test(flavor = "multi_thread")]
    async fn it_assembles_batch() {
        // Assemble a list of requests for testing
        let (receivers, requests): (Vec<_>, Vec<_>) = (0..2)
            .map(|index| {
                let (tx, rx) = oneshot::channel();
                let gql_request = graphql::Request::fake_builder()
                    .operation_name(format!("batch_test_{index}"))
                    .query(format!("query batch_test {{ slot{index} }}"))
                    .build();

                (
                    rx,
                    BatchQueryInfo {
                        request: SubgraphRequest::fake_builder()
                            .subgraph_request(
                                http::Request::builder().body(gql_request.clone()).unwrap(),
                            )
                            .subgraph_name(format!("slot{index}"))
                            .build(),
                        gql_request,
                        sender: tx,
                    },
                )
            })
            .unzip();

        // Create a vector of the input request context IDs for comparison
        let input_context_ids = requests
            .iter()
            .map(|r| r.request.context.id.clone())
            .collect::<Vec<String>>();
        // Assemble them
        let (op_name, contexts, request, txs) = assemble_batch(requests)
            .await
            .expect("it can assemble a batch");

        let output_context_ids = contexts
            .iter()
            .map(|r| r.0.id.clone())
            .collect::<Vec<String>>();
        // Make sure all of our contexts are preserved during assembly
        assert_eq!(input_context_ids, output_context_ids);

        // Make sure that the name of the entire batch is that of the first
        assert_eq!(op_name, "batch_test_0");

        // We should see the aggregation of all of the requests
        let actual: Vec<graphql::Request> = serde_json::from_str(
            std::str::from_utf8(&router::body::into_bytes(request.into_body()).await.unwrap())
                .unwrap(),
        )
        .unwrap();

        let expected: Vec<_> = (0..2)
            .map(|index| {
                graphql::Request::fake_builder()
                    .operation_name(format!("batch_test_{index}"))
                    .query(format!("query batch_test {{ slot{index} }}"))
                    .build()
            })
            .collect();
        assert_eq!(actual, expected);

        // We should also have all of the correct senders and they should be linked to the correct waiter
        // Note: assemble_batch preserves input order, so each sender pairs with its receiver by index
        assert_eq!(txs.len(), receivers.len());
        for (index, (tx, rx)) in Iterator::zip(txs.into_iter(), receivers).enumerate() {
            let data = serde_json_bytes::json!({
                "data": {
                    format!("slot{index}"): "valid"
                }
            });
            let response = SubgraphResponse {
                response: http::Response::builder()
                    .body(graphql::Response::builder().data(data.clone()).build())
                    .unwrap(),
                context: Context::new(),
                subgraph_name: String::default(),
                id: SubgraphRequestId(String::new()),
            };

            tx.send(Ok(response)).unwrap();

            // We want to make sure that we don't hang the test if we don't get the correct message
            let received = tokio::time::timeout(Duration::from_millis(10), rx)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            assert_eq!(received.response.into_body().data, Some(data));
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_rejects_index_out_of_bounds() {
        let batch = Arc::new(Batch::spawn_handler(2));

        assert!(Batch::query_for_index(batch.clone(), 2).is_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_rejects_duplicated_index_get() {
        let batch = Arc::new(Batch::spawn_handler(2));

        assert!(Batch::query_for_index(batch.clone(), 0).is_ok());
        assert!(Batch::query_for_index(batch.clone(), 0).is_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_cancelled_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));

        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");

        assert!(
            bq.set_query_hashes(vec![Arc::new(QueryHash::default())])
                .await
                .is_ok()
        );
        assert!(!bq.finished());
        assert!(bq.signal_cancelled("why not?".to_string()).await.is_ok());
        assert!(bq.finished());
        assert!(
            bq.signal_cancelled("only once though".to_string())
                .await
                .is_err()
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_progressed_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));

        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");

        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            crate::configuration::shared::Client::default(),
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        assert!(
            bq.set_query_hashes(vec![Arc::new(QueryHash::default())])
                .await
                .is_ok()
        );
        assert!(!bq.finished());
        assert!(
            bq.signal_progress(
                factory.clone(),
                request.clone(),
                graphql::Request::default()
            )
            .await
            .is_ok()
        );
        assert!(bq.finished());
        assert!(
            bq.signal_progress(factory, request, graphql::Request::default())
                .await
                .is_err()
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_mixed_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));

        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");

        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            crate::configuration::shared::Client::default(),
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        assert!(
            bq.set_query_hashes(vec![Arc::new(QueryHash::default())])
                .await
                .is_ok()
        );
        assert!(!bq.finished());
        assert!(
            bq.signal_progress(factory, request, graphql::Request::default())
                .await
                .is_ok()
        );
        assert!(bq.finished());
        assert!(
            bq.signal_cancelled("only once though".to_string())
                .await
                .is_err()
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_mixed_sends_two_query_hashes() {
        let batch = Arc::new(Batch::spawn_handler(2));

        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");

        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            crate::configuration::shared::Client::default(),
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        let qh = Arc::new(QueryHash::default());
        assert!(bq.set_query_hashes(vec![qh.clone(), qh]).await.is_ok());
        assert!(!bq.finished());
        assert!(
            bq.signal_progress(factory, request, graphql::Request::default())
                .await
                .is_ok()
        );
        assert!(!bq.finished());
        assert!(
            bq.signal_cancelled("only twice though".to_string())
                .await
                .is_ok()
        );
        assert!(bq.finished());
        assert!(
            bq.signal_cancelled("only twice though".to_string())
                .await
                .is_err()
        );
    }

    fn expect_batch(request: &wiremock::Request) -> ResponseTemplate {
        let requests: Vec<Request> = request.body_json().unwrap();

        // Extract info about this operation
        let (subgraph, count): (String, usize) = {
            let re = regex::Regex::new(r"entry([AB])\(count: ?([0-9]+)\)").unwrap();
            let captures = re.captures(requests[0].query.as_ref().unwrap()).unwrap();

            (captures[1].to_string(), captures[2].parse().unwrap())
        };

        // We should have gotten `count` elements
        assert_eq!(requests.len(), count);

        // Each element should be for the specified subgraph and should have a field selection
        // of index.
        // Note: The router rewrites the operation name, so the expected query accounts for that
        for (index, request) in requests.into_iter().enumerate() {
            assert_eq!(
                request.query,
                Some(format!(
                    "query op{index}__{}__0 {{ entry{}(count: {count}) {{ index }} }}",
                    subgraph.to_lowercase(),
                    subgraph
                ))
            );
        }

        ResponseTemplate::new(200).set_body_json(
            (0..count)
                .map(|index| {
                    serde_json::json!({
                        "data": {
                            format!("entry{subgraph}"): {
                                "index": index
                            }
                        }
                    })
                })
                .collect::<Vec<_>>(),
        )
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn it_matches_subgraph_request_ids_to_responses() {
        // Create a wiremock server to handle the batched request
        let mock_server = MockServer::start().await;
        mock_server
            .register(
                wiremock::Mock::given(matchers::method("POST"))
                    .and(matchers::path("/a"))
                    .respond_with(expect_batch)
                    .expect(1),
            )
            .await;

        let schema = include_str!("../tests/fixtures/batching/schema.graphql");
        let service = TestHarness::builder()
            .configuration_json(serde_json::json!({
            "include_subgraph_errors": {
                "all": true
            },
            "batching": {
                "enabled": true,
                "mode": "batch_http_link",
                "subgraph": {
                    "all": {
                        "enabled": true
                    }
                }
            },
            "override_subgraph_url": {
                "a": format!("{}/a", mock_server.uri())
            }}))
            .unwrap()
            .schema(schema)
            .subgraph_hook(move |_subgraph_name, service| {
                service
                    .map_future_with_request_data(
                        |r: &subgraph::Request| r.id.clone(),
                        |id, f| async move {
                            let r: subgraph::ServiceResult = f.await;
                            assert_eq!(id, r.as_ref().map(|r| r.id.clone()).unwrap());
                            r
                        },
                    )
                    .boxed()
            })
            .with_subgraph_network_requests()
            .build_router()
            .await
            .unwrap();

        let requests: Vec<_> = (0..3)
            .map(|index| {
                Request::fake_builder()
                    .query(format!("query op{index}{{ entryA(count: 3) {{ index }} }}"))
                    .build()
            })
            .collect();
        let request = serde_json::to_value(requests).unwrap();

        let context = Context::new();
        let request = router::Request {
            context,
            router_request: http::Request::builder()
                .method("POST")
                .header(CONTENT_TYPE, "application/json")
                .header(ACCEPT, "application/json")
                .body(body::from_bytes(serde_json::to_vec(&request).unwrap()))
                .unwrap(),
        };

        let response = service
            .oneshot(request)
            .await
            .unwrap()
            .next_response()
            .await
            .unwrap()
            .unwrap();

        let response: serde_json::Value = serde_json::from_slice(&response).unwrap();
        insta::assert_json_snapshot!(response);
    }
}