hive_router_plan_executor/execution/
plan.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeSet, HashMap},
4};
5
6use bytes::{BufMut, Bytes};
7use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
8use hive_router_query_planner::planner::plan_nodes::{
9    ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode,
10    QueryPlan, SequenceNode,
11};
12use http::{HeaderMap, Method};
13use ntex_http::HeaderMap as NtexHeaderMap;
14use serde::Deserialize;
15use sonic_rs::ValueRef;
16
17use crate::{
18    context::ExecutionContext,
19    execution::{
20        error::PlanExecutionError, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt,
21    },
22    executors::{
23        common::{HttpExecutionRequest, HttpExecutionResponse},
24        map::SubgraphExecutorMap,
25    },
26    headers::{
27        plan::HeaderRulesPlan,
28        request::modify_subgraph_request_headers,
29        response::{apply_subgraph_response_headers, modify_client_response_headers},
30    },
31    introspection::{
32        resolve::{resolve_introspection, IntrospectionContext},
33        schema::SchemaMetadata,
34    },
35    projection::{
36        plan::FieldProjectionPlan,
37        request::{project_requires, RequestProjectionContext},
38        response::project_by_operation,
39    },
40    response::{
41        graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath},
42        merge::deep_merge,
43        subgraph_response::SubgraphResponse,
44        value::Value,
45    },
46    utils::{
47        consts::{CLOSE_BRACKET, OPEN_BRACKET},
48        traverse::{traverse_and_callback, traverse_and_callback_mut},
49    },
50};
51
/// Describes the GraphQL operation carried by the client request.
pub struct OperationDetails<'a> {
    /// Name of the operation, if any.
    pub name: Option<String>,
    /// Text of the operation document; borrowed when possible, owned otherwise.
    pub query: Cow<'a, str>,
    /// Operation kind as text — presumably `"query"`, `"mutation"` or `"subscription"`;
    /// TODO confirm against the code that fills this in.
    pub kind: &'a str,
}
57
58pub struct ClientRequestDetails<'a> {
59    pub method: Method,
60    pub url: http::Uri,
61    pub headers: &'a NtexHeaderMap,
62    pub operation: OperationDetails<'a>,
63}
64
65pub struct QueryPlanExecutionContext<'exec> {
66    pub query_plan: &'exec QueryPlan,
67    pub projection_plan: &'exec Vec<FieldProjectionPlan>,
68    pub headers_plan: &'exec HeaderRulesPlan,
69    pub variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
70    pub extensions: Option<HashMap<String, sonic_rs::Value>>,
71    pub client_request: ClientRequestDetails<'exec>,
72    pub introspection_context: &'exec IntrospectionContext<'exec, 'static>,
73    pub operation_type_name: &'exec str,
74    pub executors: &'exec SubgraphExecutorMap,
75    pub jwt_auth_forwarding: &'exec Option<JwtAuthForwardingPlan>,
76}
77
78pub struct PlanExecutionOutput {
79    pub body: Vec<u8>,
80    pub headers: HeaderMap,
81}
82
83pub async fn execute_query_plan<'exec>(
84    ctx: QueryPlanExecutionContext<'exec>,
85) -> Result<PlanExecutionOutput, PlanExecutionError> {
86    let init_value = if let Some(introspection_query) = ctx.introspection_context.query {
87        resolve_introspection(introspection_query, ctx.introspection_context)
88    } else {
89        Value::Null
90    };
91
92    let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value);
93    let executor = Executor::new(
94        ctx.variable_values,
95        ctx.executors,
96        ctx.introspection_context.metadata,
97        &ctx.client_request,
98        ctx.headers_plan,
99        ctx.jwt_auth_forwarding,
100        // Deduplicate subgraph requests only if the operation type is a query
101        ctx.operation_type_name == "Query",
102    );
103
104    if ctx.query_plan.node.is_some() {
105        executor
106            .execute(&mut exec_ctx, ctx.query_plan.node.as_ref())
107            .await?;
108    }
109
110    let mut response_headers = HeaderMap::new();
111    modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers)?;
112
113    let final_response = &exec_ctx.final_response;
114    let body = project_by_operation(
115        final_response,
116        exec_ctx.errors,
117        &ctx.extensions,
118        ctx.operation_type_name,
119        ctx.projection_plan,
120        ctx.variable_values,
121        exec_ctx.response_storage.estimate_final_response_size(),
122    )?;
123
124    Ok(PlanExecutionOutput {
125        body,
126        headers: response_headers,
127    })
128}
129
130pub struct Executor<'exec> {
131    variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
132    schema_metadata: &'exec SchemaMetadata,
133    executors: &'exec SubgraphExecutorMap,
134    client_request: &'exec ClientRequestDetails<'exec>,
135    headers_plan: &'exec HeaderRulesPlan,
136    jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
137    dedupe_subgraph_requests: bool,
138}
139
140struct ConcurrencyScope<'exec, T> {
141    jobs: FuturesUnordered<BoxFuture<'exec, T>>,
142}
143
144impl<'exec, T> ConcurrencyScope<'exec, T> {
145    fn new() -> Self {
146        Self {
147            jobs: FuturesUnordered::new(),
148        }
149    }
150
151    fn spawn(&mut self, future: BoxFuture<'exec, T>) {
152        self.jobs.push(future);
153    }
154
155    async fn join_all(mut self) -> Vec<T> {
156        let mut results = Vec::with_capacity(self.jobs.len());
157        while let Some(result) = self.jobs.next().await {
158            results.push(result);
159        }
160        results
161    }
162}
163
164struct SubgraphOutput {
165    body: Bytes,
166    headers: HeaderMap,
167}
168
169struct FetchJob {
170    fetch_node_id: i64,
171    subgraph_name: String,
172    response: SubgraphOutput,
173}
174
175struct FlattenFetchJob {
176    flatten_node_path: FlattenNodePath,
177    response: SubgraphOutput,
178    fetch_node_id: i64,
179    subgraph_name: String,
180    representation_hashes: Vec<u64>,
181    representation_hash_to_index: HashMap<u64, usize>,
182}
183
184enum ExecutionJob {
185    Fetch(FetchJob),
186    FlattenFetch(FlattenFetchJob),
187    None,
188}
189
190impl From<ExecutionJob> for SubgraphOutput {
191    fn from(value: ExecutionJob) -> Self {
192        match value {
193            ExecutionJob::Fetch(j) => Self {
194                body: j.response.body,
195                headers: j.response.headers,
196            },
197            ExecutionJob::FlattenFetch(j) => Self {
198                body: j.response.body,
199                headers: j.response.headers,
200            },
201            ExecutionJob::None => Self {
202                body: Bytes::new(),
203                headers: HeaderMap::new(),
204            },
205        }
206    }
207}
208
209impl From<HttpExecutionResponse> for SubgraphOutput {
210    fn from(res: HttpExecutionResponse) -> Self {
211        Self {
212            body: res.body,
213            headers: res.headers,
214        }
215    }
216}
217
/// Representations collected for a flatten fetch by `Executor::prepare_flatten_data`.
struct PreparedFlattenData {
    /// Serialized array of the distinct, projected entity representations to send.
    representations: Vec<u8>,
    /// Hash of every non-null entity found at the flatten path, in traversal order.
    representation_hashes: Vec<u64>,
    /// Maps a representation hash to the index used to find its entity in the subgraph's
    /// `_entities` response.
    representation_hash_to_index: HashMap<u64, usize>,
}
223
224impl<'exec> Executor<'exec> {
225    pub fn new(
226        variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
227        executors: &'exec SubgraphExecutorMap,
228        schema_metadata: &'exec SchemaMetadata,
229        client_request: &'exec ClientRequestDetails<'exec>,
230        headers_plan: &'exec HeaderRulesPlan,
231        jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
232        dedupe_subgraph_requests: bool,
233    ) -> Self {
234        Executor {
235            variable_values,
236            executors,
237            schema_metadata,
238            client_request,
239            headers_plan,
240            dedupe_subgraph_requests,
241            jwt_forwarding_plan,
242        }
243    }
244
245    pub async fn execute(
246        &self,
247        ctx: &mut ExecutionContext<'exec>,
248        plan: Option<&PlanNode>,
249    ) -> Result<(), PlanExecutionError> {
250        match plan {
251            Some(PlanNode::Fetch(node)) => self.execute_fetch_wave(ctx, node).await,
252            Some(PlanNode::Parallel(node)) => self.execute_parallel_wave(ctx, node).await,
253            Some(PlanNode::Sequence(node)) => self.execute_sequence_wave(ctx, node).await,
254            // Plans produced by our Query Planner can only start with: Fetch, Sequence or Parallel.
255            // Any other node type at the root is not supported, do nothing
256            Some(_) => Ok(()),
257            // An empty plan is valid, just do nothing
258            None => Ok(()),
259        }
260    }
261
262    async fn execute_fetch_wave(
263        &self,
264        ctx: &mut ExecutionContext<'exec>,
265        node: &FetchNode,
266    ) -> Result<(), PlanExecutionError> {
267        match self.execute_fetch_node(node, None).await {
268            Ok(result) => self.process_job_result(ctx, result),
269            Err(err) => {
270                let extensions = GraphQLErrorExtensions::new_from_code_and_service_name(
271                    "PLAN_EXECUTION_ERROR",
272                    &node.service_name,
273                );
274                ctx.errors.push(GraphQLError::from_message_and_extensions(
275                    err.to_string(),
276                    extensions,
277                ));
278                Ok(())
279            }
280        }
281    }
282
283    async fn execute_sequence_wave(
284        &self,
285        ctx: &mut ExecutionContext<'exec>,
286        node: &SequenceNode,
287    ) -> Result<(), PlanExecutionError> {
288        for child in &node.nodes {
289            Box::pin(self.execute_plan_node(ctx, child)).await?;
290        }
291
292        Ok(())
293    }
294
295    async fn execute_parallel_wave(
296        &self,
297        ctx: &mut ExecutionContext<'exec>,
298        node: &ParallelNode,
299    ) -> Result<(), PlanExecutionError> {
300        let mut scope = ConcurrencyScope::new();
301
302        for child in &node.nodes {
303            let job_future = self.prepare_job_future(child, &ctx.final_response);
304            scope.spawn(job_future);
305        }
306
307        let results = scope.join_all().await;
308
309        for result in results {
310            match result {
311                Ok(job) => {
312                    self.process_job_result(ctx, job)?;
313                }
314                Err(err) => ctx.errors.push(GraphQLError::from_message_and_extensions(
315                    err.to_string(),
316                    GraphQLErrorExtensions::new_from_code("PLAN_EXECUTION_FAILED"),
317                )),
318            }
319        }
320
321        Ok(())
322    }
323
324    async fn execute_plan_node(
325        &self,
326        ctx: &mut ExecutionContext<'exec>,
327        node: &PlanNode,
328    ) -> Result<(), PlanExecutionError> {
329        match node {
330            PlanNode::Fetch(fetch_node) => match self.execute_fetch_node(fetch_node, None).await {
331                Ok(job) => {
332                    self.process_job_result(ctx, job)?;
333                }
334                Err(err) => ctx.errors.push(GraphQLError::from_message_and_extensions(
335                    err.to_string(),
336                    GraphQLErrorExtensions::new_from_code_and_service_name(
337                        "PLAN_EXECUTION_FAILED",
338                        &fetch_node.service_name,
339                    ),
340                )),
341            },
342            PlanNode::Parallel(parallel_node) => {
343                self.execute_parallel_wave(ctx, parallel_node).await?;
344            }
345            PlanNode::Flatten(flatten_node) => {
346                match self.prepare_flatten_data(&ctx.final_response, flatten_node) {
347                    Ok(Some(p)) => {
348                        match self
349                            .execute_flatten_fetch_node(
350                                flatten_node,
351                                Some(p.representations),
352                                Some(p.representation_hashes),
353                                Some(p.representation_hash_to_index),
354                            )
355                            .await
356                        {
357                            Ok(job) => {
358                                self.process_job_result(ctx, job)?;
359                            }
360                            Err(err) => {
361                                let service_name = service_name_from_plan_node(node);
362                                let extensions = service_name
363                                    .map(|name| {
364                                        GraphQLErrorExtensions::new_from_code_and_service_name(
365                                            "PLAN_EXECUTION_ERROR",
366                                            name,
367                                        )
368                                    })
369                                    .unwrap_or_else(|| {
370                                        GraphQLErrorExtensions::new_from_code(
371                                            "PLAN_EXECUTION_ERROR",
372                                        )
373                                    });
374                                ctx.errors.push(GraphQLError::from_message_and_extensions(
375                                    err.to_string(),
376                                    extensions,
377                                ));
378                            }
379                        }
380                    }
381                    Ok(None) => { /* do nothing */ }
382                    Err(err) => {
383                        let service_name = service_name_from_plan_node(node);
384                        let extensions = service_name
385                            .map(|name| {
386                                GraphQLErrorExtensions::new_from_code_and_service_name(
387                                    "PLAN_EXECUTION_ERROR",
388                                    name,
389                                )
390                            })
391                            .unwrap_or_else(|| {
392                                GraphQLErrorExtensions::new_from_code("PLAN_EXECUTION_ERROR")
393                            });
394
395                        ctx.errors.push(GraphQLError::from_message_and_extensions(
396                            err.to_string(),
397                            extensions,
398                        ));
399                    }
400                }
401            }
402            PlanNode::Sequence(sequence_node) => {
403                self.execute_sequence_wave(ctx, sequence_node).await?;
404            }
405            PlanNode::Condition(condition_node) => {
406                if let Some(node) =
407                    condition_node_by_variables(condition_node, self.variable_values)
408                {
409                    Box::pin(self.execute_plan_node(ctx, node)).await?;
410                }
411            }
412            // An unsupported plan node was found, do nothing.
413            _ => {}
414        }
415
416        Ok(())
417    }
418
419    fn prepare_job_future<'wave>(
420        &'wave self,
421        node: &'wave PlanNode,
422        final_response: &Value<'exec>,
423    ) -> BoxFuture<'wave, Result<ExecutionJob, PlanExecutionError>> {
424        match node {
425            PlanNode::Fetch(fetch_node) => Box::pin(self.execute_fetch_node(fetch_node, None)),
426            PlanNode::Flatten(flatten_node) => {
427                match self.prepare_flatten_data(final_response, flatten_node) {
428                    Ok(Some(p)) => Box::pin(self.execute_flatten_fetch_node(
429                        flatten_node,
430                        Some(p.representations),
431                        Some(p.representation_hashes),
432                        Some(p.representation_hash_to_index),
433                    )),
434                    Ok(None) => Box::pin(async { Ok(ExecutionJob::None) }),
435                    Err(e) => Box::pin(async move { Err(e) }),
436                }
437            }
438            PlanNode::Condition(node) => {
439                match condition_node_by_variables(node, self.variable_values) {
440                    Some(node) => Box::pin(self.prepare_job_future(node, final_response)), // This is already clean.
441                    None => Box::pin(async { Ok(ExecutionJob::None) }),
442                }
443            }
444            // Our Query Planner does not produce any other plan node types in ParallelNode
445            _ => Box::pin(async { Ok(ExecutionJob::None) }),
446        }
447    }
448
449    fn process_subgraph_response(
450        &self,
451        subgraph_name: &str,
452        ctx: &mut ExecutionContext<'exec>,
453        response_bytes: Bytes,
454        fetch_node_id: i64,
455    ) -> Option<(SubgraphResponse<'exec>, Option<&'exec Vec<FetchRewrite>>)> {
456        let idx = ctx.response_storage.add_response(response_bytes);
457        // SAFETY: The `bytes` are transmuted to the lifetime `'a` of the `ExecutionContext`.
458        // This is safe because the `response_storage` is part of the `ExecutionContext` (`ctx`)
459        // and will live as long as `'a`. The `Bytes` are stored in an `Arc`, so they won't be
460        // dropped until all references are gone. The `Value`s deserialized from this byte
461        // slice will borrow from it, and they are stored in `ctx.final_response`, which also
462        // lives for `'a`.
463        let bytes: &'exec [u8] =
464            unsafe { std::mem::transmute(ctx.response_storage.get_bytes(idx)) };
465
466        // SAFETY: The `output_rewrites` are transmuted to the lifetime `'a`. This is safe
467        // because `output_rewrites` is part of `OutputRewritesStorage` which is owned by
468        // `ExecutionContext` and lives for `'a`.
469        let output_rewrites: Option<&'exec Vec<FetchRewrite>> =
470            unsafe { std::mem::transmute(ctx.output_rewrites.get(fetch_node_id)) };
471
472        let mut deserializer = sonic_rs::Deserializer::from_slice(bytes);
473        let response = match SubgraphResponse::deserialize(&mut deserializer) {
474            Ok(response) => response,
475            Err(e) => {
476                let message = format!("Failed to deserialize subgraph response: {}", e);
477                let extensions = GraphQLErrorExtensions::new_from_code_and_service_name(
478                    "SUBGRAPH_RESPONSE_DESERIALIZATION_FAILED",
479                    subgraph_name,
480                );
481                let error = GraphQLError::from_message_and_extensions(message, extensions);
482
483                ctx.errors.push(error);
484                return None;
485            }
486        };
487
488        Some((response, output_rewrites))
489    }
490
491    fn process_job_result(
492        &self,
493        ctx: &mut ExecutionContext<'exec>,
494        job: ExecutionJob,
495    ) -> Result<(), PlanExecutionError> {
496        let _: () = match job {
497            ExecutionJob::Fetch(job) => {
498                apply_subgraph_response_headers(
499                    self.headers_plan,
500                    &job.subgraph_name,
501                    &job.response.headers,
502                    self.client_request,
503                    &mut ctx.response_headers_aggregator,
504                )?;
505
506                if let Some((mut response, output_rewrites)) = self.process_subgraph_response(
507                    job.subgraph_name.as_ref(),
508                    ctx,
509                    job.response.body,
510                    job.fetch_node_id,
511                ) {
512                    ctx.handle_errors(&job.subgraph_name, response.errors, None);
513                    if let Some(output_rewrites) = output_rewrites {
514                        for output_rewrite in output_rewrites {
515                            output_rewrite
516                                .rewrite(&self.schema_metadata.possible_types, &mut response.data);
517                        }
518                    }
519
520                    deep_merge(&mut ctx.final_response, response.data);
521                }
522            }
523            ExecutionJob::FlattenFetch(job) => {
524                apply_subgraph_response_headers(
525                    self.headers_plan,
526                    &job.subgraph_name,
527                    &job.response.headers,
528                    self.client_request,
529                    &mut ctx.response_headers_aggregator,
530                )?;
531
532                if let Some((mut response, output_rewrites)) = self.process_subgraph_response(
533                    &job.subgraph_name,
534                    ctx,
535                    job.response.body,
536                    job.fetch_node_id,
537                ) {
538                    if let Some(mut entities) = response.data.take_entities() {
539                        if let Some(output_rewrites) = output_rewrites {
540                            for output_rewrite in output_rewrites {
541                                for entity in &mut entities {
542                                    output_rewrite
543                                        .rewrite(&self.schema_metadata.possible_types, entity);
544                                }
545                            }
546                        }
547
548                        let mut index = 0;
549                        let normalized_path = job.flatten_node_path.as_slice();
550                        // If there is an error in the response, then collect the paths for normalizing the error
551                        let initial_error_path = response
552                            .errors
553                            .as_ref()
554                            .map(|_| GraphQLErrorPath::with_capacity(normalized_path.len() + 2));
555                        let mut entity_index_error_map = response
556                            .errors
557                            .as_ref()
558                            .map(|_| HashMap::with_capacity(entities.len()));
559                        traverse_and_callback_mut(
560                            &mut ctx.final_response,
561                            normalized_path,
562                            self.schema_metadata,
563                            initial_error_path,
564                            &mut |target, error_path| {
565                                let hash = job.representation_hashes[index];
566                                if let Some(entity_index) =
567                                    job.representation_hash_to_index.get(&hash)
568                                {
569                                    if let (Some(error_path), Some(entity_index_error_map)) =
570                                        (error_path, entity_index_error_map.as_mut())
571                                    {
572                                        let error_paths = entity_index_error_map
573                                            .entry(entity_index)
574                                            .or_insert_with(Vec::new);
575                                        error_paths.push(error_path);
576                                    }
577                                    if let Some(entity) = entities.get(*entity_index) {
578                                        // SAFETY: `new_val` is a clone of an entity that lives for `'a`.
579                                        // The transmute is to satisfy the compiler, but the lifetime
580                                        // is valid.
581                                        let new_val: Value<'_> =
582                                            unsafe { std::mem::transmute(entity.clone()) };
583                                        deep_merge(target, new_val);
584                                    }
585                                }
586                                index += 1;
587                            },
588                        );
589                        ctx.handle_errors(
590                            &job.subgraph_name,
591                            response.errors,
592                            entity_index_error_map,
593                        );
594                    }
595                }
596            }
597            ExecutionJob::None => {
598                // nothing to do
599            }
600        };
601        Ok(())
602    }
603
604    fn prepare_flatten_data(
605        &self,
606        final_response: &Value<'exec>,
607        flatten_node: &FlattenNode,
608    ) -> Result<Option<PreparedFlattenData>, PlanExecutionError> {
609        let fetch_node = match flatten_node.node.as_ref() {
610            PlanNode::Fetch(fetch_node) => fetch_node,
611            _ => return Ok(None),
612        };
613        let requires_nodes = match fetch_node.requires.as_ref() {
614            Some(nodes) => nodes,
615            None => return Ok(None),
616        };
617
618        let mut index = 0;
619        let normalized_path = flatten_node.path.as_slice();
620        let mut filtered_representations = Vec::new();
621        filtered_representations.put(OPEN_BRACKET);
622        let proj_ctx = RequestProjectionContext::new(&self.schema_metadata.possible_types);
623        let mut representation_hashes: Vec<u64> = Vec::new();
624        let mut filtered_representations_hashes: HashMap<u64, usize> = HashMap::new();
625        let arena = bumpalo::Bump::new();
626
627        traverse_and_callback(
628            final_response,
629            normalized_path,
630            self.schema_metadata,
631            &mut |entity| {
632                let hash = entity.to_hash(&requires_nodes.items, proj_ctx.possible_types);
633
634                if !entity.is_null() {
635                    representation_hashes.push(hash);
636                }
637
638                if filtered_representations_hashes.contains_key(&hash) {
639                    return Ok::<(), PlanExecutionError>(());
640                }
641
642                let entity = if let Some(input_rewrites) = &fetch_node.input_rewrites {
643                    let new_entity = arena.alloc(entity.clone());
644                    for input_rewrite in input_rewrites {
645                        input_rewrite.rewrite(&self.schema_metadata.possible_types, new_entity);
646                    }
647                    new_entity
648                } else {
649                    entity
650                };
651
652                let is_projected = project_requires(
653                    &proj_ctx,
654                    &requires_nodes.items,
655                    entity,
656                    &mut filtered_representations,
657                    filtered_representations_hashes.is_empty(),
658                    None,
659                )?;
660
661                if is_projected {
662                    filtered_representations_hashes.insert(hash, index);
663                }
664
665                index += 1;
666
667                Ok(())
668            },
669        )?;
670        filtered_representations.put(CLOSE_BRACKET);
671
672        if filtered_representations_hashes.is_empty() {
673            return Ok(None);
674        }
675
676        Ok(Some(PreparedFlattenData {
677            representations: filtered_representations,
678            representation_hashes,
679            representation_hash_to_index: filtered_representations_hashes,
680        }))
681    }
682
683    async fn execute_flatten_fetch_node(
684        &self,
685        node: &FlattenNode,
686        representations: Option<Vec<u8>>,
687        representation_hashes: Option<Vec<u64>>,
688        filtered_representations_hashes: Option<HashMap<u64, usize>>,
689    ) -> Result<ExecutionJob, PlanExecutionError> {
690        Ok(match node.node.as_ref() {
691            PlanNode::Fetch(fetch_node) => ExecutionJob::FlattenFetch(FlattenFetchJob {
692                flatten_node_path: node.path.clone(),
693                response: self
694                    .execute_fetch_node(fetch_node, representations)
695                    .await?
696                    .into(),
697                fetch_node_id: fetch_node.id,
698                subgraph_name: fetch_node.service_name.clone(),
699                representation_hashes: representation_hashes.unwrap_or_default(),
700                representation_hash_to_index: filtered_representations_hashes.unwrap_or_default(),
701            }),
702            _ => ExecutionJob::None,
703        })
704    }
705
706    async fn execute_fetch_node(
707        &self,
708        node: &FetchNode,
709        representations: Option<Vec<u8>>,
710    ) -> Result<ExecutionJob, PlanExecutionError> {
711        // TODO: We could optimize header map creation by caching them per service name
712        let mut headers_map = HeaderMap::new();
713        modify_subgraph_request_headers(
714            self.headers_plan,
715            &node.service_name,
716            self.client_request,
717            &mut headers_map,
718        )?;
719        let variable_refs =
720            select_fetch_variables(self.variable_values, node.variable_usages.as_ref());
721
722        let mut subgraph_request = HttpExecutionRequest {
723            query: node.operation.document_str.as_str(),
724            dedupe: self.dedupe_subgraph_requests,
725            operation_name: node.operation_name.as_deref(),
726            variables: variable_refs,
727            representations,
728            headers: headers_map,
729            extensions: None,
730        };
731
732        if let Some(jwt_forwarding_plan) = &self.jwt_forwarding_plan {
733            subgraph_request.add_request_extensions_field(
734                jwt_forwarding_plan.extension_field_name.clone(),
735                jwt_forwarding_plan.extension_field_value.clone(),
736            );
737        }
738
739        Ok(ExecutionJob::Fetch(FetchJob {
740            fetch_node_id: node.id,
741            subgraph_name: node.service_name.clone(),
742            response: self
743                .executors
744                .execute(&node.service_name, subgraph_request, self.client_request)
745                .await
746                .into(),
747        }))
748    }
749}
750
751fn service_name_from_plan_node(node: &PlanNode) -> Option<&str> {
752    match node {
753        PlanNode::Fetch(fetch_node) => Some(fetch_node.service_name.as_ref()),
754        PlanNode::Flatten(flatten_node) => service_name_from_plan_node(flatten_node.node.as_ref()),
755        PlanNode::Condition(condition_node) => condition_node
756            .if_clause
757            .as_deref()
758            .and_then(service_name_from_plan_node)
759            .or_else(|| {
760                condition_node
761                    .else_clause
762                    .as_deref()
763                    .and_then(service_name_from_plan_node)
764            }),
765        _ => None,
766    }
767}
768
769fn condition_node_by_variables<'a>(
770    condition_node: &'a ConditionNode,
771    variable_values: &'a Option<HashMap<String, sonic_rs::Value>>,
772) -> Option<&'a PlanNode> {
773    let vars = variable_values.as_ref()?;
774    let value = vars.get(&condition_node.condition)?;
775    let condition_met = matches!(value.as_ref(), ValueRef::Bool(true));
776
777    if condition_met {
778        condition_node.if_clause.as_deref()
779    } else {
780        condition_node.else_clause.as_deref()
781    }
782}
783
784fn select_fetch_variables<'a>(
785    variable_values: &'a Option<HashMap<String, sonic_rs::Value>>,
786    variable_usages: Option<&BTreeSet<String>>,
787) -> Option<HashMap<&'a str, &'a sonic_rs::Value>> {
788    let values = variable_values.as_ref()?;
789
790    variable_usages.map(|variable_usages| {
791        variable_usages
792            .iter()
793            .filter_map(|var_name| {
794                values
795                    .get_key_value(var_name.as_str())
796                    .map(|(key, value)| (key.as_str(), value))
797            })
798            .collect()
799    })
800}
801
#[cfg(test)]
mod tests {
    use crate::{
        context::ExecutionContext,
        response::graphql_error::{
            GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath, GraphQLErrorPathSegment,
        },
    };

    use super::select_fetch_variables;
    use sonic_rs::Value;
    use std::collections::{BTreeSet, HashMap};

    /// Builds a JSON number value from an `i32`, for use as a variable value in tests.
    fn value_from_number(n: i32) -> Value {
        sonic_rs::from_str(&n.to_string()).unwrap()
    }

    #[test]
    fn select_fetch_variables_only_used_variables() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("used".to_string(), value_from_number(1));
        variable_values_map.insert("unused".to_string(), value_from_number(2));
        let variable_values = Some(variable_values_map);

        let mut usages = BTreeSet::new();
        usages.insert("used".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert_eq!(selected.len(), 1);
        assert!(selected.contains_key("used"));
        assert!(!selected.contains_key("unused"));
    }

    #[test]
    fn select_fetch_variables_ignores_missing_usage_entries() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("present".to_string(), value_from_number(3));
        let variable_values = Some(variable_values_map);

        let mut usages = BTreeSet::new();
        usages.insert("present".to_string());
        usages.insert("missing".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert_eq!(selected.len(), 1);
        assert!(selected.contains_key("present"));
        assert!(!selected.contains_key("missing"));
    }

    #[test]
    fn select_fetch_variables_for_no_usage_entries() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("unused_1".to_string(), value_from_number(1));
        variable_values_map.insert("unused_2".to_string(), value_from_number(2));

        let variable_values = Some(variable_values_map);

        let selected = select_fetch_variables(&variable_values, None);

        assert!(selected.is_none());
    }

    #[test]
    fn select_fetch_variables_for_missing_variable_values() {
        let variable_values: Option<HashMap<String, Value>> = None;

        let mut usages = BTreeSet::new();
        usages.insert("used".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages));

        assert!(selected.is_none());
    }

    #[test]
    fn select_fetch_variables_for_empty_usage_set() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("unused".to_string(), value_from_number(1));
        let variable_values = Some(variable_values_map);

        let usages = BTreeSet::new();

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert!(selected.is_empty());
    }

    /// We have the same entity in two different paths `["a", 0]` and `["b", 1]`,
    /// and the subgraph response has an error for this entity.
    /// So we should duplicate the error for both paths.
    #[test]
    fn normalize_entity_errors_correctly() {
        let mut ctx = ExecutionContext::default();
        let mut entity_index_error_map: HashMap<&usize, Vec<GraphQLErrorPath>> = HashMap::new();
        entity_index_error_map.insert(
            &0,
            vec![
                GraphQLErrorPath {
                    segments: vec![
                        GraphQLErrorPathSegment::String("a".to_string()),
                        GraphQLErrorPathSegment::Index(0),
                    ],
                },
                GraphQLErrorPath {
                    segments: vec![
                        GraphQLErrorPathSegment::String("b".to_string()),
                        GraphQLErrorPathSegment::Index(1),
                    ],
                },
            ],
        );
        let response_errors = vec![GraphQLError {
            message: "Error 1".to_string(),
            locations: None,
            path: Some(GraphQLErrorPath {
                segments: vec![
                    GraphQLErrorPathSegment::String("_entities".to_string()),
                    GraphQLErrorPathSegment::Index(0),
                    GraphQLErrorPathSegment::String("field1".to_string()),
                ],
            }),
            extensions: GraphQLErrorExtensions::default(),
        }];
        ctx.handle_errors(
            "subgraph_a",
            Some(response_errors),
            Some(entity_index_error_map),
        );
        assert_eq!(ctx.errors.len(), 2);
        assert_eq!(ctx.errors[0].message, "Error 1");
        assert_eq!(
            ctx.errors[0].path.as_ref().unwrap().segments,
            vec![
                GraphQLErrorPathSegment::String("a".to_string()),
                GraphQLErrorPathSegment::Index(0),
                GraphQLErrorPathSegment::String("field1".to_string())
            ]
        );
        assert_eq!(ctx.errors[1].message, "Error 1");
        assert_eq!(
            ctx.errors[1].path.as_ref().unwrap().segments,
            vec![
                GraphQLErrorPathSegment::String("b".to_string()),
                GraphQLErrorPathSegment::Index(1),
                GraphQLErrorPathSegment::String("field1".to_string())
            ]
        );
    }
}