1use std::{
2 borrow::Cow,
3 collections::{BTreeSet, HashMap},
4};
5
6use bytes::{BufMut, Bytes};
7use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
8use hive_router_query_planner::planner::plan_nodes::{
9 ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode,
10 QueryPlan, SequenceNode,
11};
12use http::{HeaderMap, Method};
13use ntex_http::HeaderMap as NtexHeaderMap;
14use serde::Deserialize;
15use sonic_rs::ValueRef;
16
17use crate::{
18 context::ExecutionContext,
19 execution::{
20 error::PlanExecutionError, jwt_forward::JwtAuthForwardingPlan, rewrites::FetchRewriteExt,
21 },
22 executors::{
23 common::{HttpExecutionRequest, HttpExecutionResponse},
24 map::SubgraphExecutorMap,
25 },
26 headers::{
27 plan::HeaderRulesPlan,
28 request::modify_subgraph_request_headers,
29 response::{apply_subgraph_response_headers, modify_client_response_headers},
30 },
31 introspection::{
32 resolve::{resolve_introspection, IntrospectionContext},
33 schema::SchemaMetadata,
34 },
35 projection::{
36 plan::FieldProjectionPlan,
37 request::{project_requires, RequestProjectionContext},
38 response::project_by_operation,
39 },
40 response::{
41 graphql_error::{GraphQLError, GraphQLErrorExtensions, GraphQLErrorPath},
42 merge::deep_merge,
43 subgraph_response::SubgraphResponse,
44 value::Value,
45 },
46 utils::{
47 consts::{CLOSE_BRACKET, OPEN_BRACKET},
48 traverse::{traverse_and_callback, traverse_and_callback_mut},
49 },
50};
51
/// The GraphQL operation carried by the client request.
pub struct OperationDetails<'a> {
    /// Operation kind as a string slice (NOTE(review): exact set of values is defined by the
    /// caller — confirm).
    pub kind: &'a str,
    /// Operation name, when the client supplied one.
    pub name: Option<String>,
    /// The operation document text; borrowed when possible, owned otherwise.
    pub query: Cow<'a, str>,
}
57
58pub struct ClientRequestDetails<'a> {
59 pub method: Method,
60 pub url: http::Uri,
61 pub headers: &'a NtexHeaderMap,
62 pub operation: OperationDetails<'a>,
63}
64
65pub struct QueryPlanExecutionContext<'exec> {
66 pub query_plan: &'exec QueryPlan,
67 pub projection_plan: &'exec Vec<FieldProjectionPlan>,
68 pub headers_plan: &'exec HeaderRulesPlan,
69 pub variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
70 pub extensions: Option<HashMap<String, sonic_rs::Value>>,
71 pub client_request: ClientRequestDetails<'exec>,
72 pub introspection_context: &'exec IntrospectionContext<'exec, 'static>,
73 pub operation_type_name: &'exec str,
74 pub executors: &'exec SubgraphExecutorMap,
75 pub jwt_auth_forwarding: &'exec Option<JwtAuthForwardingPlan>,
76}
77
78pub struct PlanExecutionOutput {
79 pub body: Vec<u8>,
80 pub headers: HeaderMap,
81}
82
83pub async fn execute_query_plan<'exec>(
84 ctx: QueryPlanExecutionContext<'exec>,
85) -> Result<PlanExecutionOutput, PlanExecutionError> {
86 let init_value = if let Some(introspection_query) = ctx.introspection_context.query {
87 resolve_introspection(introspection_query, ctx.introspection_context)
88 } else {
89 Value::Null
90 };
91
92 let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value);
93 let executor = Executor::new(
94 ctx.variable_values,
95 ctx.executors,
96 ctx.introspection_context.metadata,
97 &ctx.client_request,
98 ctx.headers_plan,
99 ctx.jwt_auth_forwarding,
100 ctx.operation_type_name == "Query",
102 );
103
104 if ctx.query_plan.node.is_some() {
105 executor
106 .execute(&mut exec_ctx, ctx.query_plan.node.as_ref())
107 .await?;
108 }
109
110 let mut response_headers = HeaderMap::new();
111 modify_client_response_headers(exec_ctx.response_headers_aggregator, &mut response_headers)?;
112
113 let final_response = &exec_ctx.final_response;
114 let body = project_by_operation(
115 final_response,
116 exec_ctx.errors,
117 &ctx.extensions,
118 ctx.operation_type_name,
119 ctx.projection_plan,
120 ctx.variable_values,
121 exec_ctx.response_storage.estimate_final_response_size(),
122 )?;
123
124 Ok(PlanExecutionOutput {
125 body,
126 headers: response_headers,
127 })
128}
129
130pub struct Executor<'exec> {
131 variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
132 schema_metadata: &'exec SchemaMetadata,
133 executors: &'exec SubgraphExecutorMap,
134 client_request: &'exec ClientRequestDetails<'exec>,
135 headers_plan: &'exec HeaderRulesPlan,
136 jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
137 dedupe_subgraph_requests: bool,
138}
139
140struct ConcurrencyScope<'exec, T> {
141 jobs: FuturesUnordered<BoxFuture<'exec, T>>,
142}
143
144impl<'exec, T> ConcurrencyScope<'exec, T> {
145 fn new() -> Self {
146 Self {
147 jobs: FuturesUnordered::new(),
148 }
149 }
150
151 fn spawn(&mut self, future: BoxFuture<'exec, T>) {
152 self.jobs.push(future);
153 }
154
155 async fn join_all(mut self) -> Vec<T> {
156 let mut results = Vec::with_capacity(self.jobs.len());
157 while let Some(result) = self.jobs.next().await {
158 results.push(result);
159 }
160 results
161 }
162}
163
164struct SubgraphOutput {
165 body: Bytes,
166 headers: HeaderMap,
167}
168
169struct FetchJob {
170 fetch_node_id: i64,
171 subgraph_name: String,
172 response: SubgraphOutput,
173}
174
175struct FlattenFetchJob {
176 flatten_node_path: FlattenNodePath,
177 response: SubgraphOutput,
178 fetch_node_id: i64,
179 subgraph_name: String,
180 representation_hashes: Vec<u64>,
181 representation_hash_to_index: HashMap<u64, usize>,
182}
183
184enum ExecutionJob {
185 Fetch(FetchJob),
186 FlattenFetch(FlattenFetchJob),
187 None,
188}
189
190impl From<ExecutionJob> for SubgraphOutput {
191 fn from(value: ExecutionJob) -> Self {
192 match value {
193 ExecutionJob::Fetch(j) => Self {
194 body: j.response.body,
195 headers: j.response.headers,
196 },
197 ExecutionJob::FlattenFetch(j) => Self {
198 body: j.response.body,
199 headers: j.response.headers,
200 },
201 ExecutionJob::None => Self {
202 body: Bytes::new(),
203 headers: HeaderMap::new(),
204 },
205 }
206 }
207}
208
209impl From<HttpExecutionResponse> for SubgraphOutput {
210 fn from(res: HttpExecutionResponse) -> Self {
211 Self {
212 body: res.body,
213 headers: res.headers,
214 }
215 }
216}
217
/// Entity representations collected for a `Flatten` fetch, ready to send to a subgraph.
struct PreparedFlattenData {
    /// Serialized representations, written between an opening and a closing bracket.
    representations: Vec<u8>,
    /// One hash per non-null entity visited, in traversal order.
    representation_hashes: Vec<u64>,
    /// Hash of each projected entity mapped to its position in `representations`.
    representation_hash_to_index: HashMap<u64, usize>,
}
223
224impl<'exec> Executor<'exec> {
225 pub fn new(
226 variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
227 executors: &'exec SubgraphExecutorMap,
228 schema_metadata: &'exec SchemaMetadata,
229 client_request: &'exec ClientRequestDetails<'exec>,
230 headers_plan: &'exec HeaderRulesPlan,
231 jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
232 dedupe_subgraph_requests: bool,
233 ) -> Self {
234 Executor {
235 variable_values,
236 executors,
237 schema_metadata,
238 client_request,
239 headers_plan,
240 dedupe_subgraph_requests,
241 jwt_forwarding_plan,
242 }
243 }
244
245 pub async fn execute(
246 &self,
247 ctx: &mut ExecutionContext<'exec>,
248 plan: Option<&PlanNode>,
249 ) -> Result<(), PlanExecutionError> {
250 match plan {
251 Some(PlanNode::Fetch(node)) => self.execute_fetch_wave(ctx, node).await,
252 Some(PlanNode::Parallel(node)) => self.execute_parallel_wave(ctx, node).await,
253 Some(PlanNode::Sequence(node)) => self.execute_sequence_wave(ctx, node).await,
254 Some(_) => Ok(()),
257 None => Ok(()),
259 }
260 }
261
262 async fn execute_fetch_wave(
263 &self,
264 ctx: &mut ExecutionContext<'exec>,
265 node: &FetchNode,
266 ) -> Result<(), PlanExecutionError> {
267 match self.execute_fetch_node(node, None).await {
268 Ok(result) => self.process_job_result(ctx, result),
269 Err(err) => {
270 let extensions = GraphQLErrorExtensions::new_from_code_and_service_name(
271 "PLAN_EXECUTION_ERROR",
272 &node.service_name,
273 );
274 ctx.errors.push(GraphQLError::from_message_and_extensions(
275 err.to_string(),
276 extensions,
277 ));
278 Ok(())
279 }
280 }
281 }
282
283 async fn execute_sequence_wave(
284 &self,
285 ctx: &mut ExecutionContext<'exec>,
286 node: &SequenceNode,
287 ) -> Result<(), PlanExecutionError> {
288 for child in &node.nodes {
289 Box::pin(self.execute_plan_node(ctx, child)).await?;
290 }
291
292 Ok(())
293 }
294
295 async fn execute_parallel_wave(
296 &self,
297 ctx: &mut ExecutionContext<'exec>,
298 node: &ParallelNode,
299 ) -> Result<(), PlanExecutionError> {
300 let mut scope = ConcurrencyScope::new();
301
302 for child in &node.nodes {
303 let job_future = self.prepare_job_future(child, &ctx.final_response);
304 scope.spawn(job_future);
305 }
306
307 let results = scope.join_all().await;
308
309 for result in results {
310 match result {
311 Ok(job) => {
312 self.process_job_result(ctx, job)?;
313 }
314 Err(err) => ctx.errors.push(GraphQLError::from_message_and_extensions(
315 err.to_string(),
316 GraphQLErrorExtensions::new_from_code("PLAN_EXECUTION_FAILED"),
317 )),
318 }
319 }
320
321 Ok(())
322 }
323
324 async fn execute_plan_node(
325 &self,
326 ctx: &mut ExecutionContext<'exec>,
327 node: &PlanNode,
328 ) -> Result<(), PlanExecutionError> {
329 match node {
330 PlanNode::Fetch(fetch_node) => match self.execute_fetch_node(fetch_node, None).await {
331 Ok(job) => {
332 self.process_job_result(ctx, job)?;
333 }
334 Err(err) => ctx.errors.push(GraphQLError::from_message_and_extensions(
335 err.to_string(),
336 GraphQLErrorExtensions::new_from_code_and_service_name(
337 "PLAN_EXECUTION_FAILED",
338 &fetch_node.service_name,
339 ),
340 )),
341 },
342 PlanNode::Parallel(parallel_node) => {
343 self.execute_parallel_wave(ctx, parallel_node).await?;
344 }
345 PlanNode::Flatten(flatten_node) => {
346 match self.prepare_flatten_data(&ctx.final_response, flatten_node) {
347 Ok(Some(p)) => {
348 match self
349 .execute_flatten_fetch_node(
350 flatten_node,
351 Some(p.representations),
352 Some(p.representation_hashes),
353 Some(p.representation_hash_to_index),
354 )
355 .await
356 {
357 Ok(job) => {
358 self.process_job_result(ctx, job)?;
359 }
360 Err(err) => {
361 let service_name = service_name_from_plan_node(node);
362 let extensions = service_name
363 .map(|name| {
364 GraphQLErrorExtensions::new_from_code_and_service_name(
365 "PLAN_EXECUTION_ERROR",
366 name,
367 )
368 })
369 .unwrap_or_else(|| {
370 GraphQLErrorExtensions::new_from_code(
371 "PLAN_EXECUTION_ERROR",
372 )
373 });
374 ctx.errors.push(GraphQLError::from_message_and_extensions(
375 err.to_string(),
376 extensions,
377 ));
378 }
379 }
380 }
381 Ok(None) => { }
382 Err(err) => {
383 let service_name = service_name_from_plan_node(node);
384 let extensions = service_name
385 .map(|name| {
386 GraphQLErrorExtensions::new_from_code_and_service_name(
387 "PLAN_EXECUTION_ERROR",
388 name,
389 )
390 })
391 .unwrap_or_else(|| {
392 GraphQLErrorExtensions::new_from_code("PLAN_EXECUTION_ERROR")
393 });
394
395 ctx.errors.push(GraphQLError::from_message_and_extensions(
396 err.to_string(),
397 extensions,
398 ));
399 }
400 }
401 }
402 PlanNode::Sequence(sequence_node) => {
403 self.execute_sequence_wave(ctx, sequence_node).await?;
404 }
405 PlanNode::Condition(condition_node) => {
406 if let Some(node) =
407 condition_node_by_variables(condition_node, self.variable_values)
408 {
409 Box::pin(self.execute_plan_node(ctx, node)).await?;
410 }
411 }
412 _ => {}
414 }
415
416 Ok(())
417 }
418
419 fn prepare_job_future<'wave>(
420 &'wave self,
421 node: &'wave PlanNode,
422 final_response: &Value<'exec>,
423 ) -> BoxFuture<'wave, Result<ExecutionJob, PlanExecutionError>> {
424 match node {
425 PlanNode::Fetch(fetch_node) => Box::pin(self.execute_fetch_node(fetch_node, None)),
426 PlanNode::Flatten(flatten_node) => {
427 match self.prepare_flatten_data(final_response, flatten_node) {
428 Ok(Some(p)) => Box::pin(self.execute_flatten_fetch_node(
429 flatten_node,
430 Some(p.representations),
431 Some(p.representation_hashes),
432 Some(p.representation_hash_to_index),
433 )),
434 Ok(None) => Box::pin(async { Ok(ExecutionJob::None) }),
435 Err(e) => Box::pin(async move { Err(e) }),
436 }
437 }
438 PlanNode::Condition(node) => {
439 match condition_node_by_variables(node, self.variable_values) {
440 Some(node) => Box::pin(self.prepare_job_future(node, final_response)), None => Box::pin(async { Ok(ExecutionJob::None) }),
442 }
443 }
444 _ => Box::pin(async { Ok(ExecutionJob::None) }),
446 }
447 }
448
449 fn process_subgraph_response(
450 &self,
451 subgraph_name: &str,
452 ctx: &mut ExecutionContext<'exec>,
453 response_bytes: Bytes,
454 fetch_node_id: i64,
455 ) -> Option<(SubgraphResponse<'exec>, Option<&'exec Vec<FetchRewrite>>)> {
456 let idx = ctx.response_storage.add_response(response_bytes);
457 let bytes: &'exec [u8] =
464 unsafe { std::mem::transmute(ctx.response_storage.get_bytes(idx)) };
465
466 let output_rewrites: Option<&'exec Vec<FetchRewrite>> =
470 unsafe { std::mem::transmute(ctx.output_rewrites.get(fetch_node_id)) };
471
472 let mut deserializer = sonic_rs::Deserializer::from_slice(bytes);
473 let response = match SubgraphResponse::deserialize(&mut deserializer) {
474 Ok(response) => response,
475 Err(e) => {
476 let message = format!("Failed to deserialize subgraph response: {}", e);
477 let extensions = GraphQLErrorExtensions::new_from_code_and_service_name(
478 "SUBGRAPH_RESPONSE_DESERIALIZATION_FAILED",
479 subgraph_name,
480 );
481 let error = GraphQLError::from_message_and_extensions(message, extensions);
482
483 ctx.errors.push(error);
484 return None;
485 }
486 };
487
488 Some((response, output_rewrites))
489 }
490
491 fn process_job_result(
492 &self,
493 ctx: &mut ExecutionContext<'exec>,
494 job: ExecutionJob,
495 ) -> Result<(), PlanExecutionError> {
496 let _: () = match job {
497 ExecutionJob::Fetch(job) => {
498 apply_subgraph_response_headers(
499 self.headers_plan,
500 &job.subgraph_name,
501 &job.response.headers,
502 self.client_request,
503 &mut ctx.response_headers_aggregator,
504 )?;
505
506 if let Some((mut response, output_rewrites)) = self.process_subgraph_response(
507 job.subgraph_name.as_ref(),
508 ctx,
509 job.response.body,
510 job.fetch_node_id,
511 ) {
512 ctx.handle_errors(&job.subgraph_name, response.errors, None);
513 if let Some(output_rewrites) = output_rewrites {
514 for output_rewrite in output_rewrites {
515 output_rewrite
516 .rewrite(&self.schema_metadata.possible_types, &mut response.data);
517 }
518 }
519
520 deep_merge(&mut ctx.final_response, response.data);
521 }
522 }
523 ExecutionJob::FlattenFetch(job) => {
524 apply_subgraph_response_headers(
525 self.headers_plan,
526 &job.subgraph_name,
527 &job.response.headers,
528 self.client_request,
529 &mut ctx.response_headers_aggregator,
530 )?;
531
532 if let Some((mut response, output_rewrites)) = self.process_subgraph_response(
533 &job.subgraph_name,
534 ctx,
535 job.response.body,
536 job.fetch_node_id,
537 ) {
538 if let Some(mut entities) = response.data.take_entities() {
539 if let Some(output_rewrites) = output_rewrites {
540 for output_rewrite in output_rewrites {
541 for entity in &mut entities {
542 output_rewrite
543 .rewrite(&self.schema_metadata.possible_types, entity);
544 }
545 }
546 }
547
548 let mut index = 0;
549 let normalized_path = job.flatten_node_path.as_slice();
550 let initial_error_path = response
552 .errors
553 .as_ref()
554 .map(|_| GraphQLErrorPath::with_capacity(normalized_path.len() + 2));
555 let mut entity_index_error_map = response
556 .errors
557 .as_ref()
558 .map(|_| HashMap::with_capacity(entities.len()));
559 traverse_and_callback_mut(
560 &mut ctx.final_response,
561 normalized_path,
562 self.schema_metadata,
563 initial_error_path,
564 &mut |target, error_path| {
565 let hash = job.representation_hashes[index];
566 if let Some(entity_index) =
567 job.representation_hash_to_index.get(&hash)
568 {
569 if let (Some(error_path), Some(entity_index_error_map)) =
570 (error_path, entity_index_error_map.as_mut())
571 {
572 let error_paths = entity_index_error_map
573 .entry(entity_index)
574 .or_insert_with(Vec::new);
575 error_paths.push(error_path);
576 }
577 if let Some(entity) = entities.get(*entity_index) {
578 let new_val: Value<'_> =
582 unsafe { std::mem::transmute(entity.clone()) };
583 deep_merge(target, new_val);
584 }
585 }
586 index += 1;
587 },
588 );
589 ctx.handle_errors(
590 &job.subgraph_name,
591 response.errors,
592 entity_index_error_map,
593 );
594 }
595 }
596 }
597 ExecutionJob::None => {
598 }
600 };
601 Ok(())
602 }
603
604 fn prepare_flatten_data(
605 &self,
606 final_response: &Value<'exec>,
607 flatten_node: &FlattenNode,
608 ) -> Result<Option<PreparedFlattenData>, PlanExecutionError> {
609 let fetch_node = match flatten_node.node.as_ref() {
610 PlanNode::Fetch(fetch_node) => fetch_node,
611 _ => return Ok(None),
612 };
613 let requires_nodes = match fetch_node.requires.as_ref() {
614 Some(nodes) => nodes,
615 None => return Ok(None),
616 };
617
618 let mut index = 0;
619 let normalized_path = flatten_node.path.as_slice();
620 let mut filtered_representations = Vec::new();
621 filtered_representations.put(OPEN_BRACKET);
622 let proj_ctx = RequestProjectionContext::new(&self.schema_metadata.possible_types);
623 let mut representation_hashes: Vec<u64> = Vec::new();
624 let mut filtered_representations_hashes: HashMap<u64, usize> = HashMap::new();
625 let arena = bumpalo::Bump::new();
626
627 traverse_and_callback(
628 final_response,
629 normalized_path,
630 self.schema_metadata,
631 &mut |entity| {
632 let hash = entity.to_hash(&requires_nodes.items, proj_ctx.possible_types);
633
634 if !entity.is_null() {
635 representation_hashes.push(hash);
636 }
637
638 if filtered_representations_hashes.contains_key(&hash) {
639 return Ok::<(), PlanExecutionError>(());
640 }
641
642 let entity = if let Some(input_rewrites) = &fetch_node.input_rewrites {
643 let new_entity = arena.alloc(entity.clone());
644 for input_rewrite in input_rewrites {
645 input_rewrite.rewrite(&self.schema_metadata.possible_types, new_entity);
646 }
647 new_entity
648 } else {
649 entity
650 };
651
652 let is_projected = project_requires(
653 &proj_ctx,
654 &requires_nodes.items,
655 entity,
656 &mut filtered_representations,
657 filtered_representations_hashes.is_empty(),
658 None,
659 )?;
660
661 if is_projected {
662 filtered_representations_hashes.insert(hash, index);
663 }
664
665 index += 1;
666
667 Ok(())
668 },
669 )?;
670 filtered_representations.put(CLOSE_BRACKET);
671
672 if filtered_representations_hashes.is_empty() {
673 return Ok(None);
674 }
675
676 Ok(Some(PreparedFlattenData {
677 representations: filtered_representations,
678 representation_hashes,
679 representation_hash_to_index: filtered_representations_hashes,
680 }))
681 }
682
683 async fn execute_flatten_fetch_node(
684 &self,
685 node: &FlattenNode,
686 representations: Option<Vec<u8>>,
687 representation_hashes: Option<Vec<u64>>,
688 filtered_representations_hashes: Option<HashMap<u64, usize>>,
689 ) -> Result<ExecutionJob, PlanExecutionError> {
690 Ok(match node.node.as_ref() {
691 PlanNode::Fetch(fetch_node) => ExecutionJob::FlattenFetch(FlattenFetchJob {
692 flatten_node_path: node.path.clone(),
693 response: self
694 .execute_fetch_node(fetch_node, representations)
695 .await?
696 .into(),
697 fetch_node_id: fetch_node.id,
698 subgraph_name: fetch_node.service_name.clone(),
699 representation_hashes: representation_hashes.unwrap_or_default(),
700 representation_hash_to_index: filtered_representations_hashes.unwrap_or_default(),
701 }),
702 _ => ExecutionJob::None,
703 })
704 }
705
706 async fn execute_fetch_node(
707 &self,
708 node: &FetchNode,
709 representations: Option<Vec<u8>>,
710 ) -> Result<ExecutionJob, PlanExecutionError> {
711 let mut headers_map = HeaderMap::new();
713 modify_subgraph_request_headers(
714 self.headers_plan,
715 &node.service_name,
716 self.client_request,
717 &mut headers_map,
718 )?;
719 let variable_refs =
720 select_fetch_variables(self.variable_values, node.variable_usages.as_ref());
721
722 let mut subgraph_request = HttpExecutionRequest {
723 query: node.operation.document_str.as_str(),
724 dedupe: self.dedupe_subgraph_requests,
725 operation_name: node.operation_name.as_deref(),
726 variables: variable_refs,
727 representations,
728 headers: headers_map,
729 extensions: None,
730 };
731
732 if let Some(jwt_forwarding_plan) = &self.jwt_forwarding_plan {
733 subgraph_request.add_request_extensions_field(
734 jwt_forwarding_plan.extension_field_name.clone(),
735 jwt_forwarding_plan.extension_field_value.clone(),
736 );
737 }
738
739 Ok(ExecutionJob::Fetch(FetchJob {
740 fetch_node_id: node.id,
741 subgraph_name: node.service_name.clone(),
742 response: self
743 .executors
744 .execute(&node.service_name, subgraph_request, self.client_request)
745 .await
746 .into(),
747 }))
748 }
749}
750
751fn service_name_from_plan_node(node: &PlanNode) -> Option<&str> {
752 match node {
753 PlanNode::Fetch(fetch_node) => Some(fetch_node.service_name.as_ref()),
754 PlanNode::Flatten(flatten_node) => service_name_from_plan_node(flatten_node.node.as_ref()),
755 PlanNode::Condition(condition_node) => condition_node
756 .if_clause
757 .as_deref()
758 .and_then(service_name_from_plan_node)
759 .or_else(|| {
760 condition_node
761 .else_clause
762 .as_deref()
763 .and_then(service_name_from_plan_node)
764 }),
765 _ => None,
766 }
767}
768
769fn condition_node_by_variables<'a>(
770 condition_node: &'a ConditionNode,
771 variable_values: &'a Option<HashMap<String, sonic_rs::Value>>,
772) -> Option<&'a PlanNode> {
773 let vars = variable_values.as_ref()?;
774 let value = vars.get(&condition_node.condition)?;
775 let condition_met = matches!(value.as_ref(), ValueRef::Bool(true));
776
777 if condition_met {
778 condition_node.if_clause.as_deref()
779 } else {
780 condition_node.else_clause.as_deref()
781 }
782}
783
784fn select_fetch_variables<'a>(
785 variable_values: &'a Option<HashMap<String, sonic_rs::Value>>,
786 variable_usages: Option<&BTreeSet<String>>,
787) -> Option<HashMap<&'a str, &'a sonic_rs::Value>> {
788 let values = variable_values.as_ref()?;
789
790 variable_usages.map(|variable_usages| {
791 variable_usages
792 .iter()
793 .filter_map(|var_name| {
794 values
795 .get_key_value(var_name.as_str())
796 .map(|(key, value)| (key.as_str(), value))
797 })
798 .collect()
799 })
800}
801
#[cfg(test)]
mod tests {
    use crate::{
        context::ExecutionContext,
        response::graphql_error::{GraphQLErrorExtensions, GraphQLErrorPath},
    };

    use super::select_fetch_variables;
    use sonic_rs::Value;
    use std::collections::{BTreeSet, HashMap};

    /// Builds a numeric `sonic_rs::Value` for test inputs.
    fn value_from_number(n: i32) -> Value {
        sonic_rs::from_str(&n.to_string()).unwrap()
    }

    #[test]
    fn select_fetch_variables_only_used_variables() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("used".to_string(), value_from_number(1));
        variable_values_map.insert("unused".to_string(), value_from_number(2));
        let variable_values = Some(variable_values_map);

        let mut usages = BTreeSet::new();
        usages.insert("used".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert_eq!(selected.len(), 1);
        assert!(selected.contains_key("used"));
        assert!(!selected.contains_key("unused"));
    }

    #[test]
    fn select_fetch_variables_ignores_missing_usage_entries() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("present".to_string(), value_from_number(3));
        let variable_values = Some(variable_values_map);

        let mut usages = BTreeSet::new();
        usages.insert("present".to_string());
        usages.insert("missing".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert_eq!(selected.len(), 1);
        assert!(selected.contains_key("present"));
        assert!(!selected.contains_key("missing"));
    }

    #[test]
    fn select_fetch_variables_for_no_usage_entries() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("unused_1".to_string(), value_from_number(1));
        variable_values_map.insert("unused_2".to_string(), value_from_number(2));

        let variable_values = Some(variable_values_map);

        let selected = select_fetch_variables(&variable_values, None);

        assert!(selected.is_none());
    }

    #[test]
    fn select_fetch_variables_without_request_variables() {
        let variable_values: Option<HashMap<String, Value>> = None;

        let mut usages = BTreeSet::new();
        usages.insert("used".to_string());

        let selected = select_fetch_variables(&variable_values, Some(&usages));

        assert!(selected.is_none());
    }

    #[test]
    fn select_fetch_variables_with_empty_usage_set() {
        let mut variable_values_map = HashMap::new();
        variable_values_map.insert("unused".to_string(), value_from_number(1));
        let variable_values = Some(variable_values_map);

        let usages = BTreeSet::new();

        let selected = select_fetch_variables(&variable_values, Some(&usages)).unwrap();

        assert!(selected.is_empty());
    }

    #[test]
    fn normalize_entity_errors_correctly() {
        use crate::response::graphql_error::{GraphQLError, GraphQLErrorPathSegment};
        let mut ctx = ExecutionContext::default();
        let mut entity_index_error_map: HashMap<&usize, Vec<GraphQLErrorPath>> = HashMap::new();
        entity_index_error_map.insert(
            &0,
            vec![
                GraphQLErrorPath {
                    segments: vec![
                        GraphQLErrorPathSegment::String("a".to_string()),
                        GraphQLErrorPathSegment::Index(0),
                    ],
                },
                GraphQLErrorPath {
                    segments: vec![
                        GraphQLErrorPathSegment::String("b".to_string()),
                        GraphQLErrorPathSegment::Index(1),
                    ],
                },
            ],
        );
        let response_errors = vec![GraphQLError {
            message: "Error 1".to_string(),
            locations: None,
            path: Some(GraphQLErrorPath {
                segments: vec![
                    GraphQLErrorPathSegment::String("_entities".to_string()),
                    GraphQLErrorPathSegment::Index(0),
                    GraphQLErrorPathSegment::String("field1".to_string()),
                ],
            }),
            extensions: GraphQLErrorExtensions::default(),
        }];
        ctx.handle_errors(
            "subgraph_a",
            Some(response_errors),
            Some(entity_index_error_map),
        );
        assert_eq!(ctx.errors.len(), 2);
        assert_eq!(ctx.errors[0].message, "Error 1");
        assert_eq!(
            ctx.errors[0].path.as_ref().unwrap().segments,
            vec![
                GraphQLErrorPathSegment::String("a".to_string()),
                GraphQLErrorPathSegment::Index(0),
                GraphQLErrorPathSegment::String("field1".to_string())
            ]
        );
        assert_eq!(ctx.errors[1].message, "Error 1");
        assert_eq!(
            ctx.errors[1].path.as_ref().unwrap().segments,
            vec![
                GraphQLErrorPathSegment::String("b".to_string()),
                GraphQLErrorPathSegment::Index(1),
                GraphQLErrorPathSegment::String("field1".to_string())
            ]
        );
    }
}